Refactor entire system to be more contextual so that conversation flow can be more easily managed
This commit is contained in:
parent
0d909edd44
commit
7f5e34e437
97
context.go
Normal file
97
context.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package go_llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Context struct {
|
||||||
|
context.Context
|
||||||
|
request Request
|
||||||
|
response *ResponseChoice
|
||||||
|
toolcall *ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
|
||||||
|
var res Request
|
||||||
|
|
||||||
|
res.Toolbox = c.request.Toolbox
|
||||||
|
res.Temperature = c.request.Temperature
|
||||||
|
|
||||||
|
res.Conversation = make([]Input, len(c.request.Conversation))
|
||||||
|
copy(res.Conversation, c.request.Conversation)
|
||||||
|
|
||||||
|
// now for every input message, convert those to an Input to add to the conversation
|
||||||
|
for _, msg := range c.request.Messages {
|
||||||
|
res.Conversation = append(res.Conversation, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are tool calls, then we need to add those to the conversation
|
||||||
|
if c.response != nil {
|
||||||
|
for _, call := range c.response.Calls {
|
||||||
|
res.Conversation = append(res.Conversation, call)
|
||||||
|
|
||||||
|
if c.response.Content != "" || c.response.Refusal != "" {
|
||||||
|
res.Conversation = append(res.Conversation, Message{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Text: c.response.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are tool results, then we need to add those to the conversation
|
||||||
|
for _, result := range toolResults {
|
||||||
|
res.Conversation = append(res.Conversation, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context {
|
||||||
|
return &Context{Context: ctx, request: request, response: response, toolcall: toolcall}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Request() Request {
|
||||||
|
return c.request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) WithContext(ctx context.Context) *Context {
|
||||||
|
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) WithRequest(request Request) *Context {
|
||||||
|
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) WithResponse(response *ResponseChoice) *Context {
|
||||||
|
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
|
||||||
|
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
||||||
|
return c.Context.Deadline()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Done() <-chan struct{} {
|
||||||
|
return c.Context.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Err() error {
|
||||||
|
return c.Context.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Value(key any) any {
|
||||||
|
if key == "request" {
|
||||||
|
return c.request
|
||||||
|
}
|
||||||
|
return c.Context.Value(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) {
|
||||||
|
ctx, cancel := context.WithTimeout(c.Context, timeout)
|
||||||
|
return c.WithContext(ctx), cancel
|
||||||
|
}
|
18
function.go
18
function.go
@ -31,7 +31,7 @@ type Function struct {
|
|||||||
definition *jsonschema.Definition
|
definition *jsonschema.Definition
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
func (f *Function) Execute(ctx *Context, input string) (string, error) {
|
||||||
if !f.fn.IsValid() {
|
if !f.fn.IsValid() {
|
||||||
return "", fmt.Errorf("function %s is not implemented", f.Name)
|
return "", fmt.Errorf("function %s is not implemented", f.Name)
|
||||||
}
|
}
|
||||||
@ -46,7 +46,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// now we can call the function
|
// now we can call the function
|
||||||
exec := func(ctx context.Context) (string, error) {
|
exec := func(ctx *Context) (string, error) {
|
||||||
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
||||||
|
|
||||||
if len(out) != 2 {
|
if len(out) != 2 {
|
||||||
@ -62,7 +62,7 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
|||||||
|
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
if f.Timeout > 0 {
|
if f.Timeout > 0 {
|
||||||
ctx, cancel = context.WithTimeout(ctx, f.Timeout)
|
ctx, cancel = ctx.WithTimeout(f.Timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,3 +90,15 @@ type FunctionCall struct {
|
|||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Arguments string `json:"arguments,omitempty"`
|
Arguments string `json:"arguments,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fc *FunctionCall) toRaw() map[string]any {
|
||||||
|
res := map[string]interface{}{
|
||||||
|
"name": fc.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc.Arguments != "" {
|
||||||
|
res["arguments"] = fc.Arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package go_llm
|
package go_llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
@ -13,7 +12,7 @@ import (
|
|||||||
// The struct parameters can have the following tags:
|
// The struct parameters can have the following tags:
|
||||||
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
|
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
|
||||||
|
|
||||||
func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
|
func NewFunction[T any](name string, description string, fn func(*Context, T) (string, error)) *Function {
|
||||||
var o T
|
var o T
|
||||||
|
|
||||||
res := Function{
|
res := Function{
|
||||||
|
2
go.sum
2
go.sum
@ -37,8 +37,6 @@ github.com/liushuangls/go-anthropic/v2 v2.13.0 h1:f7KJ54IHxIpHPPhrCzs3SrdP2PfErX
|
|||||||
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
|
github.com/liushuangls/go-anthropic/v2 v2.13.0/go.mod h1:5ZwRLF5TQ+y5s/MC9Z1IJYx9WUFgQCKfqFM2xreIQLk=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI=
|
|
||||||
github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
|
||||||
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
|
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
|
||||||
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
141
llm.go
141
llm.go
@ -2,6 +2,7 @@ package go_llm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Role string
|
type Role string
|
||||||
@ -18,6 +19,26 @@ type Image struct {
|
|||||||
Url string
|
Url string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i Image) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"base64": i.Base64,
|
||||||
|
"contenttype": i.ContentType,
|
||||||
|
"url": i.Url,
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Image) fromRaw(raw map[string]any) Image {
|
||||||
|
var res Image
|
||||||
|
|
||||||
|
res.Base64 = raw["base64"].(string)
|
||||||
|
res.ContentType = raw["contenttype"].(string)
|
||||||
|
res.Url = raw["url"].(string)
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role Role
|
Role Role
|
||||||
Name string
|
Name string
|
||||||
@ -25,10 +46,66 @@ type Message struct {
|
|||||||
Images []Image
|
Images []Image
|
||||||
}
|
}
|
||||||
|
|
||||||
type Request struct {
|
func (m Message) toRaw() map[string]any {
|
||||||
Messages []Message
|
res := map[string]any{
|
||||||
Toolbox *ToolBox
|
"role": m.Role,
|
||||||
Temperature *float32
|
"name": m.Name,
|
||||||
|
"text": m.Text,
|
||||||
|
}
|
||||||
|
|
||||||
|
images := make([]map[string]any, 0, len(m.Images))
|
||||||
|
for _, img := range m.Images {
|
||||||
|
images = append(images, img.toRaw())
|
||||||
|
}
|
||||||
|
|
||||||
|
res["images"] = images
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) fromRaw(raw map[string]any) Message {
|
||||||
|
var res Message
|
||||||
|
|
||||||
|
res.Role = Role(raw["role"].(string))
|
||||||
|
res.Name = raw["name"].(string)
|
||||||
|
res.Text = raw["text"].(string)
|
||||||
|
|
||||||
|
images := raw["images"].([]map[string]any)
|
||||||
|
for _, img := range images {
|
||||||
|
var i Image
|
||||||
|
|
||||||
|
res.Images = append(res.Images, i.fromRaw(img))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Message) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||||
|
var res openai.ChatCompletionMessage
|
||||||
|
|
||||||
|
res.Role = string(m.Role)
|
||||||
|
res.Name = m.Name
|
||||||
|
res.Content = m.Text
|
||||||
|
|
||||||
|
for _, img := range m.Images {
|
||||||
|
if img.Base64 != "" {
|
||||||
|
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageURL: &openai.ChatMessageImageURL{
|
||||||
|
URL: "data:" + img.ContentType + ";base64," + img.Base64,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else if img.Url != "" {
|
||||||
|
res.MultiContent = append(res.MultiContent, openai.ChatMessagePart{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageURL: &openai.ChatMessageImageURL{
|
||||||
|
URL: img.Url,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []openai.ChatCompletionMessage{res}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
@ -36,16 +113,54 @@ type ToolCall struct {
|
|||||||
FunctionCall FunctionCall
|
FunctionCall FunctionCall
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseChoice struct {
|
func (t ToolCall) toRaw() map[string]any {
|
||||||
Index int
|
res := map[string]any{
|
||||||
Role Role
|
"id": t.ID,
|
||||||
Content string
|
}
|
||||||
Refusal string
|
|
||||||
Name string
|
res["function"] = t.FunctionCall.toRaw()
|
||||||
Calls []ToolCall
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
type Response struct {
|
|
||||||
Choices []ResponseChoice
|
func (t ToolCall) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||||
|
return []openai.ChatCompletionMessage{{
|
||||||
|
Role: openai.ChatMessageRoleTool,
|
||||||
|
ToolCallID: t.ID,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallResponse struct {
|
||||||
|
ID string
|
||||||
|
Result string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t ToolCallResponse) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"id": t.ID,
|
||||||
|
"result": t.Result,
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Error != nil {
|
||||||
|
res["error"] = t.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t ToolCallResponse) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||||
|
var refusal string
|
||||||
|
if t.Error != nil {
|
||||||
|
refusal = t.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
return []openai.ChatCompletionMessage{{
|
||||||
|
Role: openai.ChatMessageRoleTool,
|
||||||
|
Content: t.Result,
|
||||||
|
Refusal: refusal,
|
||||||
|
ToolCallID: t.ID,
|
||||||
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletion interface {
|
type ChatCompletion interface {
|
||||||
|
46
openai.go
46
openai.go
@ -3,6 +3,7 @@ package go_llm
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
oai "github.com/sashabaranov/go-openai"
|
oai "github.com/sashabaranov/go-openai"
|
||||||
@ -15,47 +16,17 @@ type openaiImpl struct {
|
|||||||
|
|
||||||
var _ LLM = openaiImpl{}
|
var _ LLM = openaiImpl{}
|
||||||
|
|
||||||
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||||
res := oai.ChatCompletionRequest{
|
res := oai.ChatCompletionRequest{
|
||||||
Model: o.model,
|
Model: o.model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, i := range request.Conversation {
|
||||||
|
res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
|
||||||
|
}
|
||||||
|
|
||||||
for _, msg := range request.Messages {
|
for _, msg := range request.Messages {
|
||||||
m := oai.ChatCompletionMessage{
|
res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
|
||||||
Content: msg.Text,
|
|
||||||
Role: string(msg.Role),
|
|
||||||
Name: msg.Name,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, img := range msg.Images {
|
|
||||||
if img.Base64 != "" {
|
|
||||||
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
|
||||||
Type: "image_url",
|
|
||||||
ImageURL: &oai.ChatMessageImageURL{
|
|
||||||
URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
} else if img.Url != "" {
|
|
||||||
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
|
||||||
Type: "image_url",
|
|
||||||
ImageURL: &oai.ChatMessageImageURL{
|
|
||||||
URL: img.Url,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
|
|
||||||
if len(m.MultiContent) > 0 && m.Content != "" {
|
|
||||||
m.MultiContent = append([]oai.ChatMessagePart{{
|
|
||||||
Type: "text",
|
|
||||||
Text: m.Content,
|
|
||||||
}}, m.MultiContent...)
|
|
||||||
|
|
||||||
m.Content = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Messages = append(res.Messages, m)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Toolbox != nil {
|
if request.Toolbox != nil {
|
||||||
@ -130,8 +101,9 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
|||||||
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||||
cl := oai.NewClient(o.key)
|
cl := oai.NewClient(o.key)
|
||||||
|
|
||||||
req := o.requestToOpenAIRequest(request)
|
req := o.newRequestToOpenAIRequest(request)
|
||||||
|
|
||||||
|
slog.Info("openaiImpl.ChatComplete", "req", fmt.Sprintf("%#v", req))
|
||||||
resp, err := cl.CreateChatCompletion(ctx, req)
|
resp, err := cl.CreateChatCompletion(ctx, req)
|
||||||
|
|
||||||
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
||||||
|
54
request.go
Normal file
54
request.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package go_llm
|
||||||
|
|
||||||
|
import "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
|
type rawAble interface {
|
||||||
|
toRaw() map[string]any
|
||||||
|
fromRaw(raw map[string]any) Input
|
||||||
|
}
|
||||||
|
|
||||||
|
type Input interface {
|
||||||
|
toChatCompletionMessages() []openai.ChatCompletionMessage
|
||||||
|
}
|
||||||
|
type Request struct {
|
||||||
|
Conversation []Input
|
||||||
|
Messages []Message
|
||||||
|
Toolbox *ToolBox
|
||||||
|
Temperature *float32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
|
||||||
|
// return a new request with the conversation updated to include the response and tool results.
|
||||||
|
func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request {
|
||||||
|
var res Request
|
||||||
|
|
||||||
|
res.Toolbox = req.Toolbox
|
||||||
|
res.Temperature = req.Temperature
|
||||||
|
|
||||||
|
res.Conversation = make([]Input, len(req.Conversation))
|
||||||
|
copy(res.Conversation, req.Conversation)
|
||||||
|
|
||||||
|
// now for every input message, convert those to an Input to add to the conversation
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
res.Conversation = append(res.Conversation, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are tool calls, then we need to add those to the conversation
|
||||||
|
for _, call := range resp.Calls {
|
||||||
|
res.Conversation = append(res.Conversation, call)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Content != "" || resp.Refusal != "" {
|
||||||
|
res.Conversation = append(res.Conversation, Message{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Text: resp.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are tool results, then we need to add those to the conversation
|
||||||
|
for _, result := range toolResults {
|
||||||
|
res.Conversation = append(res.Conversation, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
70
response.go
Normal file
70
response.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package go_llm
|
||||||
|
|
||||||
|
import "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
|
type ResponseChoice struct {
|
||||||
|
Index int
|
||||||
|
Role Role
|
||||||
|
Content string
|
||||||
|
Refusal string
|
||||||
|
Name string
|
||||||
|
Calls []ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ResponseChoice) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"index": r.Index,
|
||||||
|
"role": r.Role,
|
||||||
|
"content": r.Content,
|
||||||
|
"refusal": r.Refusal,
|
||||||
|
"name": r.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
calls := make([]map[string]any, 0, len(r.Calls))
|
||||||
|
for _, call := range r.Calls {
|
||||||
|
calls = append(calls, call.toRaw())
|
||||||
|
}
|
||||||
|
|
||||||
|
res["calls"] = calls
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ResponseChoice) toChatCompletionMessages() []openai.ChatCompletionMessage {
|
||||||
|
var res []openai.ChatCompletionMessage
|
||||||
|
|
||||||
|
for _, call := range r.Calls {
|
||||||
|
res = append(res, call.toChatCompletionMessages()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Refusal != "" || r.Content != "" {
|
||||||
|
res = append(res, openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
|
Content: r.Content,
|
||||||
|
Refusal: r.Refusal,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ResponseChoice) toInput() []Input {
|
||||||
|
var res []Input
|
||||||
|
|
||||||
|
for _, call := range r.Calls {
|
||||||
|
res = append(res, call)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Content != "" || r.Refusal != "" {
|
||||||
|
res = append(res, Message{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Text: r.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
Choices []ResponseChoice
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package go_llm
|
package go_llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
@ -64,7 +63,7 @@ var (
|
|||||||
ErrFunctionNotFound = errors.New("function not found")
|
ErrFunctionNotFound = errors.New("function not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
|
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (string, error) {
|
||||||
f, ok := t.names[functionName]
|
f, ok := t.names[functionName]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -74,6 +73,6 @@ func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, para
|
|||||||
return f.Execute(ctx, params)
|
return f.Execute(ctx, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
|
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) {
|
||||||
return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user