diff --git a/anthropic.go b/anthropic.go index 90ebdf5..aa81a90 100644 --- a/anthropic.go +++ b/anthropic.go @@ -134,15 +134,16 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest { } } - if req.Toolbox != nil { - for _, tool := range req.Toolbox.funcs { + /* + for _, tool := range req.Toolbox.functions { res.Tools = append(res.Tools, anth.ToolDefinition{ Name: tool.Name, Description: tool.Description, - InputSchema: tool.Parameters, + InputSchema: tool.Parameters.OpenAIParameters(), }) } - } + + */ res.Messages = msgs diff --git a/context.go b/context.go index 8320269..0f8a838 100644 --- a/context.go +++ b/context.go @@ -7,9 +7,10 @@ import ( type Context struct { context.Context - request Request - response *ResponseChoice - toolcall *ToolCall + request Request + response *ResponseChoice + toolcall *ToolCall + syntheticFields map[string]string } func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request { @@ -55,20 +56,32 @@ func (c *Context) ToolCall() *ToolCall { return c.toolcall } +func (c *Context) SyntheticFields() map[string]string { + if c.syntheticFields == nil { + c.syntheticFields = map[string]string{} + } + + return c.syntheticFields +} + func (c *Context) WithContext(ctx context.Context) *Context { - return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall} + return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} } func (c *Context) WithRequest(request Request) *Context { - return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall} + return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} } func (c *Context) WithResponse(response *ResponseChoice) *Context { - return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall} + return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} } func (c *Context) WithToolCall(toolcall *ToolCall) *Context { - return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall} + return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields} +} + +func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context { + return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields} } func (c *Context) Deadline() (deadline time.Time, ok bool) { @@ -84,8 +97,19 @@ func (c *Context) Err() error { } func (c *Context) Value(key any) any { - if key == "request" { + switch key { + case "request": return c.request + + case "response": + return c.response + + case "toolcall": + return c.toolcall + + case "syntheticFields": + return c.syntheticFields + } return c.Context.Value(key) } diff --git a/function.go b/function.go index 7ffe6b1..9bdd17a 100644 --- a/function.go +++ b/function.go @@ -4,11 +4,10 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "reflect" "time" - "github.com/sashabaranov/go-openai" - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" ) @@ -29,16 +28,63 @@ type Function struct { paramType reflect.Type } -func (f *Function) Execute(ctx *Context, input string) (any, error) { +func (f Function) WithSyntheticField(name string, description string) Function { + if obj, o := f.Parameters.(schema.Object); o { + f.Parameters = obj.WithSyntheticField(name, description) + } + + return f +} + +func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function { + if obj, o := f.Parameters.(schema.Object); o { + for k, v := range fieldsAndDescriptions { + obj = obj.WithSyntheticField(k, v) + } + f.Parameters = obj + } + + return f +} + +func (f Function) Execute(ctx *Context, input string) (any, error) { if !f.fn.IsValid() { return "", fmt.Errorf("function %s is not implemented", f.Name) } + slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType) // first, we need to parse the input into the struct p := reflect.New(f.paramType) fmt.Println("Function.Execute", f.Name, "input:", input) - //m := map[string]any{} - err := json.Unmarshal([]byte(input), p.Interface()) + + var vals map[string]any + err := json.Unmarshal([]byte(input), &vals) + + var syntheticFields map[string]string + + // first eat up any synthetic fields + if obj, o := f.Parameters.(schema.Object); o { + for k := range obj.SyntheticFields() { + key := schema.SyntheticFieldPrefix + k + if val, ok := vals[key]; ok { + if syntheticFields == nil { + syntheticFields = map[string]string{} + } + + syntheticFields[k] = fmt.Sprint(val) + delete(vals, key) + } + } + } + + // now for any remaining fields, re-marshal them into json and then unmarshal into the struct + b, err := json.Marshal(vals) + if err != nil { + return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input) + } + + // now we can unmarshal the input into the struct + err = json.Unmarshal(b, p.Interface()) if err != nil { return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input) } @@ -67,15 +113,6 @@ func (f *Function) Execute(ctx *Context, input string) (any, error) { return exec(ctx) } -func (f *Function) toOpenAIFunction() *openai.FunctionDefinition { - return &openai.FunctionDefinition{ - Name: f.Name, - Description: f.Description, - Strict: f.Strict, - Parameters: f.Parameters, - } -} - type FunctionCall struct { Name string `json:"name,omitempty"` Arguments string `json:"arguments,omitempty"` diff --git a/functions.go b/functions.go index e1b7ccc..71a0a1c 100644 --- a/functions.go +++ b/functions.go @@ -13,7 +13,7 @@ import ( // The struct parameters can have the following tags: // - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for -func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) *Function { +func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function { var o T res := Function{ @@ -31,5 +31,5 @@ func NewFunction[T any](name string, description string, fn func(*Context, T) (a panic("function parameter must be a struct") } - return &res + return res } diff --git a/go.mod b/go.mod index 8454d65..e9596d7 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,7 @@ go 1.23.1 require ( github.com/google/generative-ai-go v0.19.0 github.com/liushuangls/go-anthropic/v2 v2.15.0 - github.com/openai/openai-go v0.1.0-beta.6 - github.com/sashabaranov/go-openai v1.38.1 + github.com/openai/openai-go v0.1.0-beta.9 google.golang.org/api v0.228.0 ) @@ -35,14 +34,14 @@ require ( go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/net v0.38.0 // indirect + golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.29.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.11.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect google.golang.org/grpc v1.71.1 // indirect google.golang.org/protobuf v1.36.6 // indirect ) diff --git a/go.sum b/go.sum index 869d1a3..331f8b6 100644 --- a/go.sum +++ b/go.sum @@ -37,10 +37,10 @@ github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9 github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4= github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ= github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= +github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo= +github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8= -github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -73,6 +73,8 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= @@ -87,8 +89,12 @@ google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs= google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4= google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 h1:Qbb5RVn5xzI4naMJSpJ7lhvmos6UwZkbekd5Uz7rt9E= google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:6T35kB3IPpdw7Wul09by0G/JuOuIFkXV6OOvt8IZeT8= +google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs= +google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac= google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 h1:0K7wTWyzxZ7J+L47+LbFogJW1nn/gnnMCN0vGXNYtTI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= diff --git a/google.go b/google.go index 8eea66f..b371978 100644 --- a/google.go +++ b/google.go @@ -23,25 +23,22 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) { func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) { res := *model - if in.Toolbox != nil { - for _, tool := range in.Toolbox.funcs { - - res.Tools = append(res.Tools, &genai.Tool{ - FunctionDeclarations: []*genai.FunctionDeclaration{ - { - Name: tool.Name, - Description: tool.Description, - Parameters: tool.Parameters.GoogleParameters(), - }, + for _, tool := range in.Toolbox.functions { + res.Tools = append(res.Tools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{ + { + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters.GoogleParameters(), }, - }) - } + }, + }) + } - if !in.Toolbox.dontRequireTool { - res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{ - Mode: genai.FunctionCallingAny, - }} - } + if !in.Toolbox.RequiresTool() { + res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: genai.FunctionCallingAny, + }} } cs := res.StartChat() diff --git a/openai.go b/openai.go index 80c6d40..929ab6c 100644 --- a/openai.go +++ b/openai.go @@ -31,23 +31,21 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...) } - if request.Toolbox != nil { - for _, tool := range request.Toolbox.funcs { - res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ - Type: "function", - Function: shared.FunctionDefinitionParam{ - Name: tool.Name, - Description: openai.String(tool.Description), - Strict: openai.Bool(tool.Strict), - Parameters: tool.Parameters.OpenAIParameters(), - }, - }) - } + for _, tool := range request.Toolbox.functions { + res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ + Type: "function", + Function: shared.FunctionDefinitionParam{ + Name: tool.Name, + Description: openai.String(tool.Description), + Strict: openai.Bool(tool.Strict), + Parameters: tool.Parameters.OpenAIParameters(), + }, + }) + } - if !request.Toolbox.dontRequireTool { - res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ - OfAuto: openai.String("required"), - } + if request.Toolbox.RequiresTool() { + res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String("required"), } } diff --git a/request.go b/request.go index 961fae7..1a39e33 100644 --- a/request.go +++ b/request.go @@ -15,7 +15,7 @@ type Input interface { type Request struct { Conversation []Input Messages []Message - Toolbox *ToolBox + Toolbox ToolBox Temperature *float64 } diff --git a/schema/GetType.go b/schema/GetType.go index 4aa365b..342dea2 100644 --- a/schema/GetType.go +++ b/schema/GetType.go @@ -25,27 +25,27 @@ func getFromType(t reflect.Type, b basic) Type { switch t.Kind() { case reflect.String: - b.DataType = String + b.DataType = TypeString b.typeName = "string" return b case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - b.DataType = Integer + b.DataType = TypeInteger b.typeName = "integer" return b case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - b.DataType = Integer + b.DataType = TypeInteger b.typeName = "integer" return b case reflect.Float32, reflect.Float64: - b.DataType = Number + b.DataType = TypeNumber b.typeName = "number" return b case reflect.Bool: - b.DataType = Boolean + b.DataType = TypeBoolean b.typeName = "boolean" return b @@ -92,7 +92,7 @@ func getField(f reflect.StructField, index int) Type { } } - b.DataType = String + b.DataType = TypeString b.typeName = "string" return enum{ basic: b, @@ -104,15 +104,26 @@ func getField(f reflect.StructField, index int) Type { return getFromType(t, b) } -func getObject(t reflect.Type) object { +func getObject(t reflect.Type) Object { fields := make(map[string]Type, t.NumField()) for i := 0; i < t.NumField(); i++ { field := t.Field(i) - fields[field.Name] = getField(field, i) + + if field.Anonymous { + // if the field is anonymous, we need to get the fields of the anonymous struct + // and add them to the object + anon := getObject(field.Type) + for k, v := range anon.fields { + fields[k] = v + } + continue + } else { + fields[field.Name] = getField(field, i) + } } - return object{ - basic: basic{DataType: Object, typeName: "object"}, + return Object{ + basic: basic{DataType: TypeObject, typeName: "object"}, fields: fields, } } @@ -120,7 +131,7 @@ func getObject(t reflect.Type) object { func getArray(t reflect.Type) array { res := array{ basic: basic{ - DataType: Array, + DataType: TypeArray, typeName: "array", }, } diff --git a/schema/basic.go b/schema/basic.go index e7f824c..087b794 100644 --- a/schema/basic.go +++ b/schema/basic.go @@ -15,12 +15,12 @@ var _ Type = basic{} type DataType string const ( - String DataType = "string" - Integer DataType = "integer" - Number DataType = "number" - Boolean DataType = "boolean" - Object DataType = "object" - Array DataType = "array" + TypeString DataType = "string" + TypeInteger DataType = "integer" + TypeNumber DataType = "number" + TypeBoolean DataType = "boolean" + TypeObject DataType = "object" + TypeArray DataType = "array" ) type basic struct { @@ -49,17 +49,17 @@ func (b basic) GoogleParameters() *genai.Schema { var t = genai.TypeUnspecified switch b.DataType { - case String: + case TypeString: t = genai.TypeString - case Integer: + case TypeInteger: t = genai.TypeInteger - case Number: + case TypeNumber: t = genai.TypeNumber - case Boolean: + case TypeBoolean: t = genai.TypeBoolean - case Object: + case TypeObject: t = genai.TypeObject - case Array: + case TypeArray: t = genai.TypeArray default: t = genai.TypeUnspecified @@ -82,12 +82,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) { v := reflect.ValueOf(val) switch b.DataType { - case String: + case TypeString: var val = v.String() return reflect.ValueOf(val), nil - case Integer: + case TypeInteger: if v.Kind() == reflect.Float64 { return v.Convert(reflect.TypeOf(int(0))), nil } else if v.Kind() != reflect.Int { @@ -96,7 +96,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) { return v, nil } - case Number: + case TypeNumber: if v.Kind() == reflect.Float64 { return v.Convert(reflect.TypeOf(float64(0))), nil } else if v.Kind() != reflect.Float64 { @@ -105,7 +105,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) { return v, nil } - case Boolean: + case TypeBoolean: if v.Kind() == reflect.Bool { return v, nil } else if v.Kind() == reflect.String { diff --git a/schema/object.go b/schema/object.go index 3dd733d..7e96c44 100644 --- a/schema/object.go +++ b/schema/object.go @@ -8,15 +8,44 @@ import ( "github.com/openai/openai-go" ) -type object struct { +const ( + // SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent + // collisions with the fields in the struct. + SyntheticFieldPrefix = "__" +) + +type Object struct { basic ref reflect.Type fields map[string]Type + + // syntheticFields are fields that are not in the struct but are generated by a system. + synetheticFields map[string]Type } -func (o object) OpenAIParameters() openai.FunctionParameters { +func (o Object) WithSyntheticField(name string, description string) Object { + if o.synetheticFields == nil { + o.synetheticFields = map[string]Type{} + } + + o.synetheticFields[name] = basic{ + DataType: TypeString, + typeName: "string", + index: -1, + required: false, + description: description, + } + + return o +} + +func (o Object) SyntheticFields() map[string]Type { + return o.synetheticFields +} + +func (o Object) OpenAIParameters() openai.FunctionParameters { var properties = map[string]openai.FunctionParameters{} var required []string for k, v := range o.fields { @@ -26,6 +55,13 @@ func (o object) OpenAIParameters() openai.FunctionParameters { } } + for k, v := range o.synetheticFields { + properties[SyntheticFieldPrefix+k] = v.OpenAIParameters() + if v.Required() { + required = append(required, SyntheticFieldPrefix+k) + } + } + var res = openai.FunctionParameters{ "type": "object", "description": o.Description(), @@ -39,7 +75,7 @@ func (o object) OpenAIParameters() openai.FunctionParameters { return res } -func (o object) GoogleParameters() *genai.Schema { +func (o Object) GoogleParameters() *genai.Schema { var properties = map[string]*genai.Schema{} var required []string for k, v := range o.fields { @@ -62,7 +98,8 @@ func (o object) GoogleParameters() *genai.Schema { return res } -func (o object) FromAny(val any) (reflect.Value, error) { +// FromAny converts the value from any to the correct type, returning the value, and an error if any +func (o Object) FromAny(val any) (reflect.Value, error) { // if the value is nil, we can't do anything if val == nil { return reflect.Value{}, nil @@ -99,7 +136,7 @@ func (o object) FromAny(val any) (reflect.Value, error) { return obj, nil } -func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) { +func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) { // if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value if !o.required { val = val.Addr() diff --git a/toolbox.go b/toolbox.go index 61fb44e..6737362 100644 --- a/toolbox.go +++ b/toolbox.go @@ -4,79 +4,82 @@ import ( "context" "errors" "fmt" - - "github.com/sashabaranov/go-openai" ) // ToolBox is a collection of tools that OpenAI can use to execute functions. // It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with // the correct parameters. type ToolBox struct { - funcs []Function - names map[string]Function + functions map[string]Function dontRequireTool bool } -func NewToolBox(fns ...*Function) *ToolBox { +func NewToolBox(fns ...Function) ToolBox { res := ToolBox{ - funcs: []Function{}, - names: map[string]Function{}, + functions: map[string]Function{}, } for _, f := range fns { - o := *f - res.names[o.Name] = o - res.funcs = append(res.funcs, o) - } - - return &res -} - -func (t *ToolBox) WithFunction(f Function) *ToolBox { - t2 := *t - t2.names[f.Name] = f - t2.funcs = append(t2.funcs, f) - - return &t2 -} - -func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox { - t2 := *t - - delete(t2.names, name) - - for i, f := range t2.funcs { - if f.Name == name { - t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...) - break - } - } - - return &t2 -} - -func (t *ToolBox) WithRequireTool(val bool) *ToolBox { - t2 := *t - t2.dontRequireTool = !val - return &t2 -} - -// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API. -func (t *ToolBox) toOpenAI() []openai.Tool { - var res []openai.Tool - - for _, f := range t.funcs { - res = append(res, openai.Tool{ - Type: "function", - Function: f.toOpenAIFunction(), - }) + res.functions[f.Name] = f } return res } -func (t *ToolBox) ToToolChoice() any { - if len(t.funcs) == 0 { +func (t ToolBox) Functions() []Function { + var res []Function + + for _, f := range t.functions { + res = append(res, f) + } + + return res +} + +func (t ToolBox) WithFunction(f Function) ToolBox { + t.functions[f.Name] = f + + return t +} + +func (t ToolBox) WithFunctions(fns ...Function) ToolBox { + for _, f := range fns { + t.functions[f.Name] = f + } + + return t +} + +func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox { + for k, v := range t.functions { + t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions) + } + + return t +} + +func (t ToolBox) ForEachFunction(fn func(f Function)) { + for _, f := range t.functions { + fn(f) + } +} + +func (t ToolBox) WithFunctionRemoved(name string) ToolBox { + delete(t.functions, name) + return t +} + +func (t ToolBox) WithRequireTool(val bool) ToolBox { + t.dontRequireTool = !val + return t +} + +func (t ToolBox) RequiresTool() bool { + return !t.dontRequireTool && len(t.functions) > 0 +} + +func (t ToolBox) ToToolChoice() any { + if len(t.functions) == 0 { return nil } @@ -87,8 +90,8 @@ var ( ErrFunctionNotFound = errors.New("function not found") ) -func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { - f, ok := t.names[functionName] +func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { + f, ok := t.functions[functionName] if !ok { return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) @@ -97,14 +100,29 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri return f.Execute(ctx, params) } -func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { +func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) } +func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string { + val := ctx.Value("syntheticParameters") + + if val == nil { + return nil + } + + syntheticParameters, ok := val.(map[string]string) + if !ok { + return nil + } + + return syntheticParameters +} + // ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. // OnNewFunction is called when a new function is created // OnFunctionFinished is called when a function is finished -func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { +func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { var res []ToolCallResponse for _, call := range toolCalls {