Refactor toolbox and function handling to support synthetic fields and improve type definitions
This commit is contained in:
		| @@ -134,15 +134,16 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if req.Toolbox != nil { | 	/* | ||||||
| 		for _, tool := range req.Toolbox.funcs { | 		for _, tool := range req.Toolbox.functions { | ||||||
| 			res.Tools = append(res.Tools, anth.ToolDefinition{ | 			res.Tools = append(res.Tools, anth.ToolDefinition{ | ||||||
| 				Name:        tool.Name, | 				Name:        tool.Name, | ||||||
| 				Description: tool.Description, | 				Description: tool.Description, | ||||||
| 				InputSchema: tool.Parameters, | 				InputSchema: tool.Parameters.OpenAIParameters(), | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 	} |  | ||||||
|  | 	*/ | ||||||
|  |  | ||||||
| 	res.Messages = msgs | 	res.Messages = msgs | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								context.go
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								context.go
									
									
									
									
									
								
							| @@ -10,6 +10,7 @@ type Context struct { | |||||||
| 	request         Request | 	request         Request | ||||||
| 	response        *ResponseChoice | 	response        *ResponseChoice | ||||||
| 	toolcall        *ToolCall | 	toolcall        *ToolCall | ||||||
|  | 	syntheticFields map[string]string | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request { | func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request { | ||||||
| @@ -55,20 +56,32 @@ func (c *Context) ToolCall() *ToolCall { | |||||||
| 	return c.toolcall | 	return c.toolcall | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (c *Context) SyntheticFields() map[string]string { | ||||||
|  | 	if c.syntheticFields == nil { | ||||||
|  | 		c.syntheticFields = map[string]string{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return c.syntheticFields | ||||||
|  | } | ||||||
|  |  | ||||||
| func (c *Context) WithContext(ctx context.Context) *Context { | func (c *Context) WithContext(ctx context.Context) *Context { | ||||||
| 	return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall} | 	return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Context) WithRequest(request Request) *Context { | func (c *Context) WithRequest(request Request) *Context { | ||||||
| 	return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall} | 	return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Context) WithResponse(response *ResponseChoice) *Context { | func (c *Context) WithResponse(response *ResponseChoice) *Context { | ||||||
| 	return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall} | 	return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Context) WithToolCall(toolcall *ToolCall) *Context { | func (c *Context) WithToolCall(toolcall *ToolCall) *Context { | ||||||
| 	return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall} | 	return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context { | ||||||
|  | 	return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Context) Deadline() (deadline time.Time, ok bool) { | func (c *Context) Deadline() (deadline time.Time, ok bool) { | ||||||
| @@ -84,8 +97,19 @@ func (c *Context) Err() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Context) Value(key any) any { | func (c *Context) Value(key any) any { | ||||||
| 	if key == "request" { | 	switch key { | ||||||
|  | 	case "request": | ||||||
| 		return c.request | 		return c.request | ||||||
|  |  | ||||||
|  | 	case "response": | ||||||
|  | 		return c.response | ||||||
|  |  | ||||||
|  | 	case "toolcall": | ||||||
|  | 		return c.toolcall | ||||||
|  |  | ||||||
|  | 	case "syntheticFields": | ||||||
|  | 		return c.syntheticFields | ||||||
|  |  | ||||||
| 	} | 	} | ||||||
| 	return c.Context.Value(key) | 	return c.Context.Value(key) | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										65
									
								
								function.go
									
									
									
									
									
								
							
							
						
						
									
										65
									
								
								function.go
									
									
									
									
									
								
							| @@ -4,11 +4,10 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"log/slog" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/sashabaranov/go-openai" |  | ||||||
|  |  | ||||||
| 	"gitea.stevedudenhoeffer.com/steve/go-llm/schema" | 	"gitea.stevedudenhoeffer.com/steve/go-llm/schema" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -29,16 +28,63 @@ type Function struct { | |||||||
| 	paramType reflect.Type | 	paramType reflect.Type | ||||||
| } | } | ||||||
|  |  | ||||||
| func (f *Function) Execute(ctx *Context, input string) (any, error) { | func (f Function) WithSyntheticField(name string, description string) Function { | ||||||
|  | 	if obj, o := f.Parameters.(schema.Object); o { | ||||||
|  | 		f.Parameters = obj.WithSyntheticField(name, description) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return f | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function { | ||||||
|  | 	if obj, o := f.Parameters.(schema.Object); o { | ||||||
|  | 		for k, v := range fieldsAndDescriptions { | ||||||
|  | 			obj = obj.WithSyntheticField(k, v) | ||||||
|  | 		} | ||||||
|  | 		f.Parameters = obj | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return f | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (f Function) Execute(ctx *Context, input string) (any, error) { | ||||||
| 	if !f.fn.IsValid() { | 	if !f.fn.IsValid() { | ||||||
| 		return "", fmt.Errorf("function %s is not implemented", f.Name) | 		return "", fmt.Errorf("function %s is not implemented", f.Name) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType) | ||||||
| 	// first, we need to parse the input into the struct | 	// first, we need to parse the input into the struct | ||||||
| 	p := reflect.New(f.paramType) | 	p := reflect.New(f.paramType) | ||||||
| 	fmt.Println("Function.Execute", f.Name, "input:", input) | 	fmt.Println("Function.Execute", f.Name, "input:", input) | ||||||
| 	//m := map[string]any{} |  | ||||||
| 	err := json.Unmarshal([]byte(input), p.Interface()) | 	var vals map[string]any | ||||||
|  | 	err := json.Unmarshal([]byte(input), &vals) | ||||||
|  |  | ||||||
|  | 	var syntheticFields map[string]string | ||||||
|  |  | ||||||
|  | 	// first eat up any synthetic fields | ||||||
|  | 	if obj, o := f.Parameters.(schema.Object); o { | ||||||
|  | 		for k := range obj.SyntheticFields() { | ||||||
|  | 			key := schema.SyntheticFieldPrefix + k | ||||||
|  | 			if val, ok := vals[key]; ok { | ||||||
|  | 				if syntheticFields == nil { | ||||||
|  | 					syntheticFields = map[string]string{} | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				syntheticFields[k] = fmt.Sprint(val) | ||||||
|  | 				delete(vals, key) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// now for any remaining fields, re-marshal them into json and then unmarshal into the struct | ||||||
|  | 	b, err := json.Marshal(vals) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// now we can unmarshal the input into the struct | ||||||
|  | 	err = json.Unmarshal(b, p.Interface()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input) | 		return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input) | ||||||
| 	} | 	} | ||||||
| @@ -67,15 +113,6 @@ func (f *Function) Execute(ctx *Context, input string) (any, error) { | |||||||
| 	return exec(ctx) | 	return exec(ctx) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (f *Function) toOpenAIFunction() *openai.FunctionDefinition { |  | ||||||
| 	return &openai.FunctionDefinition{ |  | ||||||
| 		Name:        f.Name, |  | ||||||
| 		Description: f.Description, |  | ||||||
| 		Strict:      f.Strict, |  | ||||||
| 		Parameters:  f.Parameters, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type FunctionCall struct { | type FunctionCall struct { | ||||||
| 	Name      string `json:"name,omitempty"` | 	Name      string `json:"name,omitempty"` | ||||||
| 	Arguments string `json:"arguments,omitempty"` | 	Arguments string `json:"arguments,omitempty"` | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ import ( | |||||||
| // The struct parameters can have the following tags: | // The struct parameters can have the following tags: | ||||||
| // - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for | // - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for | ||||||
|  |  | ||||||
| func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) *Function { | func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function { | ||||||
| 	var o T | 	var o T | ||||||
|  |  | ||||||
| 	res := Function{ | 	res := Function{ | ||||||
| @@ -31,5 +31,5 @@ func NewFunction[T any](name string, description string, fn func(*Context, T) (a | |||||||
| 		panic("function parameter must be a struct") | 		panic("function parameter must be a struct") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &res | 	return res | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										9
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								go.mod
									
									
									
									
									
								
							| @@ -5,8 +5,7 @@ go 1.23.1 | |||||||
| require ( | require ( | ||||||
| 	github.com/google/generative-ai-go v0.19.0 | 	github.com/google/generative-ai-go v0.19.0 | ||||||
| 	github.com/liushuangls/go-anthropic/v2 v2.15.0 | 	github.com/liushuangls/go-anthropic/v2 v2.15.0 | ||||||
| 	github.com/openai/openai-go v0.1.0-beta.6 | 	github.com/openai/openai-go v0.1.0-beta.9 | ||||||
| 	github.com/sashabaranov/go-openai v1.38.1 |  | ||||||
| 	google.golang.org/api v0.228.0 | 	google.golang.org/api v0.228.0 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -35,14 +34,14 @@ require ( | |||||||
| 	go.opentelemetry.io/otel/metric v1.35.0 // indirect | 	go.opentelemetry.io/otel/metric v1.35.0 // indirect | ||||||
| 	go.opentelemetry.io/otel/trace v1.35.0 // indirect | 	go.opentelemetry.io/otel/trace v1.35.0 // indirect | ||||||
| 	golang.org/x/crypto v0.37.0 // indirect | 	golang.org/x/crypto v0.37.0 // indirect | ||||||
| 	golang.org/x/net v0.38.0 // indirect | 	golang.org/x/net v0.39.0 // indirect | ||||||
| 	golang.org/x/oauth2 v0.29.0 // indirect | 	golang.org/x/oauth2 v0.29.0 // indirect | ||||||
| 	golang.org/x/sync v0.13.0 // indirect | 	golang.org/x/sync v0.13.0 // indirect | ||||||
| 	golang.org/x/sys v0.32.0 // indirect | 	golang.org/x/sys v0.32.0 // indirect | ||||||
| 	golang.org/x/text v0.24.0 // indirect | 	golang.org/x/text v0.24.0 // indirect | ||||||
| 	golang.org/x/time v0.11.0 // indirect | 	golang.org/x/time v0.11.0 // indirect | ||||||
| 	google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 // indirect | 	google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect | ||||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 // indirect | 	google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect | ||||||
| 	google.golang.org/grpc v1.71.1 // indirect | 	google.golang.org/grpc v1.71.1 // indirect | ||||||
| 	google.golang.org/protobuf v1.36.6 // indirect | 	google.golang.org/protobuf v1.36.6 // indirect | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.sum
									
									
									
									
									
								
							| @@ -37,10 +37,10 @@ github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9 | |||||||
| github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4= | github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4= | ||||||
| github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ= | github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ= | ||||||
| github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= | github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= | ||||||
|  | github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo= | ||||||
|  | github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= | ||||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||||
| github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8= |  | ||||||
| github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= |  | ||||||
| github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= | ||||||
| github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= | ||||||
| github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= | ||||||
| @@ -73,6 +73,8 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= | |||||||
| golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= | ||||||
| golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= | ||||||
| golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= | ||||||
|  | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= | ||||||
|  | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= | ||||||
| golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= | golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= | ||||||
| golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= | golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= | ||||||
| golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= | ||||||
| @@ -87,8 +89,12 @@ google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs= | |||||||
| google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4= | google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4= | ||||||
| google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 h1:Qbb5RVn5xzI4naMJSpJ7lhvmos6UwZkbekd5Uz7rt9E= | google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 h1:Qbb5RVn5xzI4naMJSpJ7lhvmos6UwZkbekd5Uz7rt9E= | ||||||
| google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:6T35kB3IPpdw7Wul09by0G/JuOuIFkXV6OOvt8IZeT8= | google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:6T35kB3IPpdw7Wul09by0G/JuOuIFkXV6OOvt8IZeT8= | ||||||
|  | google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs= | ||||||
|  | google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac= | ||||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 h1:0K7wTWyzxZ7J+L47+LbFogJW1nn/gnnMCN0vGXNYtTI= | google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 h1:0K7wTWyzxZ7J+L47+LbFogJW1nn/gnnMCN0vGXNYtTI= | ||||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= | google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= | ||||||
|  | google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI= | ||||||
|  | google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= | ||||||
| google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= | google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= | ||||||
| google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= | google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= | ||||||
| google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= | google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= | ||||||
|   | |||||||
| @@ -23,9 +23,7 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) { | |||||||
| func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) { | func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) { | ||||||
| 	res := *model | 	res := *model | ||||||
|  |  | ||||||
| 	if in.Toolbox != nil { | 	for _, tool := range in.Toolbox.functions { | ||||||
| 		for _, tool := range in.Toolbox.funcs { |  | ||||||
|  |  | ||||||
| 		res.Tools = append(res.Tools, &genai.Tool{ | 		res.Tools = append(res.Tools, &genai.Tool{ | ||||||
| 			FunctionDeclarations: []*genai.FunctionDeclaration{ | 			FunctionDeclarations: []*genai.FunctionDeclaration{ | ||||||
| 				{ | 				{ | ||||||
| @@ -37,12 +35,11 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) ( | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 		if !in.Toolbox.dontRequireTool { | 	if !in.Toolbox.RequiresTool() { | ||||||
| 		res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{ | 		res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{ | ||||||
| 			Mode: genai.FunctionCallingAny, | 			Mode: genai.FunctionCallingAny, | ||||||
| 		}} | 		}} | ||||||
| 	} | 	} | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	cs := res.StartChat() | 	cs := res.StartChat() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -31,8 +31,7 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple | |||||||
| 		res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...) | 		res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if request.Toolbox != nil { | 	for _, tool := range request.Toolbox.functions { | ||||||
| 		for _, tool := range request.Toolbox.funcs { |  | ||||||
| 		res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ | 		res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ | ||||||
| 			Type: "function", | 			Type: "function", | ||||||
| 			Function: shared.FunctionDefinitionParam{ | 			Function: shared.FunctionDefinitionParam{ | ||||||
| @@ -44,12 +43,11 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 		if !request.Toolbox.dontRequireTool { | 	if request.Toolbox.RequiresTool() { | ||||||
| 		res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ | 		res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ | ||||||
| 			OfAuto: openai.String("required"), | 			OfAuto: openai.String("required"), | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if request.Temperature != nil { | 	if request.Temperature != nil { | ||||||
| 		res.Temperature = openai.Float(*request.Temperature) | 		res.Temperature = openai.Float(*request.Temperature) | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ type Input interface { | |||||||
| type Request struct { | type Request struct { | ||||||
| 	Conversation []Input | 	Conversation []Input | ||||||
| 	Messages     []Message | 	Messages     []Message | ||||||
| 	Toolbox      *ToolBox | 	Toolbox      ToolBox | ||||||
| 	Temperature  *float64 | 	Temperature  *float64 | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -25,27 +25,27 @@ func getFromType(t reflect.Type, b basic) Type { | |||||||
|  |  | ||||||
| 	switch t.Kind() { | 	switch t.Kind() { | ||||||
| 	case reflect.String: | 	case reflect.String: | ||||||
| 		b.DataType = String | 		b.DataType = TypeString | ||||||
| 		b.typeName = "string" | 		b.typeName = "string" | ||||||
| 		return b | 		return b | ||||||
|  |  | ||||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||||
| 		b.DataType = Integer | 		b.DataType = TypeInteger | ||||||
| 		b.typeName = "integer" | 		b.typeName = "integer" | ||||||
| 		return b | 		return b | ||||||
|  |  | ||||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||||
| 		b.DataType = Integer | 		b.DataType = TypeInteger | ||||||
| 		b.typeName = "integer" | 		b.typeName = "integer" | ||||||
| 		return b | 		return b | ||||||
|  |  | ||||||
| 	case reflect.Float32, reflect.Float64: | 	case reflect.Float32, reflect.Float64: | ||||||
| 		b.DataType = Number | 		b.DataType = TypeNumber | ||||||
| 		b.typeName = "number" | 		b.typeName = "number" | ||||||
| 		return b | 		return b | ||||||
|  |  | ||||||
| 	case reflect.Bool: | 	case reflect.Bool: | ||||||
| 		b.DataType = Boolean | 		b.DataType = TypeBoolean | ||||||
| 		b.typeName = "boolean" | 		b.typeName = "boolean" | ||||||
| 		return b | 		return b | ||||||
|  |  | ||||||
| @@ -92,7 +92,7 @@ func getField(f reflect.StructField, index int) Type { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		b.DataType = String | 		b.DataType = TypeString | ||||||
| 		b.typeName = "string" | 		b.typeName = "string" | ||||||
| 		return enum{ | 		return enum{ | ||||||
| 			basic:  b, | 			basic:  b, | ||||||
| @@ -104,15 +104,26 @@ func getField(f reflect.StructField, index int) Type { | |||||||
| 	return getFromType(t, b) | 	return getFromType(t, b) | ||||||
| } | } | ||||||
|  |  | ||||||
| func getObject(t reflect.Type) object { | func getObject(t reflect.Type) Object { | ||||||
| 	fields := make(map[string]Type, t.NumField()) | 	fields := make(map[string]Type, t.NumField()) | ||||||
| 	for i := 0; i < t.NumField(); i++ { | 	for i := 0; i < t.NumField(); i++ { | ||||||
| 		field := t.Field(i) | 		field := t.Field(i) | ||||||
|  |  | ||||||
|  | 		if field.Anonymous { | ||||||
|  | 			// if the field is anonymous, we need to get the fields of the anonymous struct | ||||||
|  | 			// and add them to the object | ||||||
|  | 			anon := getObject(field.Type) | ||||||
|  | 			for k, v := range anon.fields { | ||||||
|  | 				fields[k] = v | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} else { | ||||||
| 			fields[field.Name] = getField(field, i) | 			fields[field.Name] = getField(field, i) | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return object{ | 	return Object{ | ||||||
| 		basic:  basic{DataType: Object, typeName: "object"}, | 		basic:  basic{DataType: TypeObject, typeName: "object"}, | ||||||
| 		fields: fields, | 		fields: fields, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -120,7 +131,7 @@ func getObject(t reflect.Type) object { | |||||||
| func getArray(t reflect.Type) array { | func getArray(t reflect.Type) array { | ||||||
| 	res := array{ | 	res := array{ | ||||||
| 		basic: basic{ | 		basic: basic{ | ||||||
| 			DataType: Array, | 			DataType: TypeArray, | ||||||
| 			typeName: "array", | 			typeName: "array", | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -15,12 +15,12 @@ var _ Type = basic{} | |||||||
| type DataType string | type DataType string | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	String  DataType = "string" | 	TypeString  DataType = "string" | ||||||
| 	Integer DataType = "integer" | 	TypeInteger DataType = "integer" | ||||||
| 	Number  DataType = "number" | 	TypeNumber  DataType = "number" | ||||||
| 	Boolean DataType = "boolean" | 	TypeBoolean DataType = "boolean" | ||||||
| 	Object  DataType = "object" | 	TypeObject  DataType = "object" | ||||||
| 	Array   DataType = "array" | 	TypeArray   DataType = "array" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type basic struct { | type basic struct { | ||||||
| @@ -49,17 +49,17 @@ func (b basic) GoogleParameters() *genai.Schema { | |||||||
| 	var t = genai.TypeUnspecified | 	var t = genai.TypeUnspecified | ||||||
|  |  | ||||||
| 	switch b.DataType { | 	switch b.DataType { | ||||||
| 	case String: | 	case TypeString: | ||||||
| 		t = genai.TypeString | 		t = genai.TypeString | ||||||
| 	case Integer: | 	case TypeInteger: | ||||||
| 		t = genai.TypeInteger | 		t = genai.TypeInteger | ||||||
| 	case Number: | 	case TypeNumber: | ||||||
| 		t = genai.TypeNumber | 		t = genai.TypeNumber | ||||||
| 	case Boolean: | 	case TypeBoolean: | ||||||
| 		t = genai.TypeBoolean | 		t = genai.TypeBoolean | ||||||
| 	case Object: | 	case TypeObject: | ||||||
| 		t = genai.TypeObject | 		t = genai.TypeObject | ||||||
| 	case Array: | 	case TypeArray: | ||||||
| 		t = genai.TypeArray | 		t = genai.TypeArray | ||||||
| 	default: | 	default: | ||||||
| 		t = genai.TypeUnspecified | 		t = genai.TypeUnspecified | ||||||
| @@ -82,12 +82,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) { | |||||||
| 	v := reflect.ValueOf(val) | 	v := reflect.ValueOf(val) | ||||||
|  |  | ||||||
| 	switch b.DataType { | 	switch b.DataType { | ||||||
| 	case String: | 	case TypeString: | ||||||
| 		var val = v.String() | 		var val = v.String() | ||||||
|  |  | ||||||
| 		return reflect.ValueOf(val), nil | 		return reflect.ValueOf(val), nil | ||||||
|  |  | ||||||
| 	case Integer: | 	case TypeInteger: | ||||||
| 		if v.Kind() == reflect.Float64 { | 		if v.Kind() == reflect.Float64 { | ||||||
| 			return v.Convert(reflect.TypeOf(int(0))), nil | 			return v.Convert(reflect.TypeOf(int(0))), nil | ||||||
| 		} else if v.Kind() != reflect.Int { | 		} else if v.Kind() != reflect.Int { | ||||||
| @@ -96,7 +96,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) { | |||||||
| 			return v, nil | 			return v, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	case Number: | 	case TypeNumber: | ||||||
| 		if v.Kind() == reflect.Float64 { | 		if v.Kind() == reflect.Float64 { | ||||||
| 			return v.Convert(reflect.TypeOf(float64(0))), nil | 			return v.Convert(reflect.TypeOf(float64(0))), nil | ||||||
| 		} else if v.Kind() != reflect.Float64 { | 		} else if v.Kind() != reflect.Float64 { | ||||||
| @@ -105,7 +105,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) { | |||||||
| 			return v, nil | 			return v, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	case Boolean: | 	case TypeBoolean: | ||||||
| 		if v.Kind() == reflect.Bool { | 		if v.Kind() == reflect.Bool { | ||||||
| 			return v, nil | 			return v, nil | ||||||
| 		} else if v.Kind() == reflect.String { | 		} else if v.Kind() == reflect.String { | ||||||
|   | |||||||
| @@ -8,15 +8,44 @@ import ( | |||||||
| 	"github.com/openai/openai-go" | 	"github.com/openai/openai-go" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type object struct { | const ( | ||||||
|  | 	// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent | ||||||
|  | 	// collisions with the fields in the struct. | ||||||
|  | 	SyntheticFieldPrefix = "__" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Object struct { | ||||||
| 	basic | 	basic | ||||||
|  |  | ||||||
| 	ref reflect.Type | 	ref reflect.Type | ||||||
|  |  | ||||||
| 	fields map[string]Type | 	fields map[string]Type | ||||||
|  |  | ||||||
|  | 	// syntheticFields are fields that are not in the struct but are generated by a system. | ||||||
|  | 	synetheticFields map[string]Type | ||||||
| } | } | ||||||
|  |  | ||||||
| func (o object) OpenAIParameters() openai.FunctionParameters { | func (o Object) WithSyntheticField(name string, description string) Object { | ||||||
|  | 	if o.synetheticFields == nil { | ||||||
|  | 		o.synetheticFields = map[string]Type{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	o.synetheticFields[name] = basic{ | ||||||
|  | 		DataType:    TypeString, | ||||||
|  | 		typeName:    "string", | ||||||
|  | 		index:       -1, | ||||||
|  | 		required:    false, | ||||||
|  | 		description: description, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return o | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (o Object) SyntheticFields() map[string]Type { | ||||||
|  | 	return o.synetheticFields | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (o Object) OpenAIParameters() openai.FunctionParameters { | ||||||
| 	var properties = map[string]openai.FunctionParameters{} | 	var properties = map[string]openai.FunctionParameters{} | ||||||
| 	var required []string | 	var required []string | ||||||
| 	for k, v := range o.fields { | 	for k, v := range o.fields { | ||||||
| @@ -26,6 +55,13 @@ func (o object) OpenAIParameters() openai.FunctionParameters { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	for k, v := range o.synetheticFields { | ||||||
|  | 		properties[SyntheticFieldPrefix+k] = v.OpenAIParameters() | ||||||
|  | 		if v.Required() { | ||||||
|  | 			required = append(required, SyntheticFieldPrefix+k) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var res = openai.FunctionParameters{ | 	var res = openai.FunctionParameters{ | ||||||
| 		"type":        "object", | 		"type":        "object", | ||||||
| 		"description": o.Description(), | 		"description": o.Description(), | ||||||
| @@ -39,7 +75,7 @@ func (o object) OpenAIParameters() openai.FunctionParameters { | |||||||
| 	return res | 	return res | ||||||
| } | } | ||||||
|  |  | ||||||
| func (o object) GoogleParameters() *genai.Schema { | func (o Object) GoogleParameters() *genai.Schema { | ||||||
| 	var properties = map[string]*genai.Schema{} | 	var properties = map[string]*genai.Schema{} | ||||||
| 	var required []string | 	var required []string | ||||||
| 	for k, v := range o.fields { | 	for k, v := range o.fields { | ||||||
| @@ -62,7 +98,8 @@ func (o object) GoogleParameters() *genai.Schema { | |||||||
| 	return res | 	return res | ||||||
| } | } | ||||||
|  |  | ||||||
| func (o object) FromAny(val any) (reflect.Value, error) { | // FromAny converts the value from any to the correct type, returning the value, and an error if any | ||||||
|  | func (o Object) FromAny(val any) (reflect.Value, error) { | ||||||
| 	// if the value is nil, we can't do anything | 	// if the value is nil, we can't do anything | ||||||
| 	if val == nil { | 	if val == nil { | ||||||
| 		return reflect.Value{}, nil | 		return reflect.Value{}, nil | ||||||
| @@ -99,7 +136,7 @@ func (o object) FromAny(val any) (reflect.Value, error) { | |||||||
| 	return obj, nil | 	return obj, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) { | func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) { | ||||||
| 	// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value | 	// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value | ||||||
| 	if !o.required { | 	if !o.required { | ||||||
| 		val = val.Addr() | 		val = val.Addr() | ||||||
|   | |||||||
							
								
								
									
										136
									
								
								toolbox.go
									
									
									
									
									
								
							
							
						
						
									
										136
									
								
								toolbox.go
									
									
									
									
									
								
							| @@ -4,79 +4,82 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
| 	"github.com/sashabaranov/go-openai" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ToolBox is a collection of tools that OpenAI can use to execute functions. | // ToolBox is a collection of tools that OpenAI can use to execute functions. | ||||||
| // It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with | // It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with | ||||||
| // the correct parameters. | // the correct parameters. | ||||||
| type ToolBox struct { | type ToolBox struct { | ||||||
| 	funcs           []Function | 	functions       map[string]Function | ||||||
| 	names           map[string]Function |  | ||||||
| 	dontRequireTool bool | 	dontRequireTool bool | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewToolBox(fns ...*Function) *ToolBox { | func NewToolBox(fns ...Function) ToolBox { | ||||||
| 	res := ToolBox{ | 	res := ToolBox{ | ||||||
| 		funcs: []Function{}, | 		functions: map[string]Function{}, | ||||||
| 		names: map[string]Function{}, |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, f := range fns { | 	for _, f := range fns { | ||||||
| 		o := *f | 		res.functions[f.Name] = f | ||||||
| 		res.names[o.Name] = o |  | ||||||
| 		res.funcs = append(res.funcs, o) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &res |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *ToolBox) WithFunction(f Function) *ToolBox { |  | ||||||
| 	t2 := *t |  | ||||||
| 	t2.names[f.Name] = f |  | ||||||
| 	t2.funcs = append(t2.funcs, f) |  | ||||||
|  |  | ||||||
| 	return &t2 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox { |  | ||||||
| 	t2 := *t |  | ||||||
|  |  | ||||||
| 	delete(t2.names, name) |  | ||||||
|  |  | ||||||
| 	for i, f := range t2.funcs { |  | ||||||
| 		if f.Name == name { |  | ||||||
| 			t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...) |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &t2 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *ToolBox) WithRequireTool(val bool) *ToolBox { |  | ||||||
| 	t2 := *t |  | ||||||
| 	t2.dontRequireTool = !val |  | ||||||
| 	return &t2 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API. |  | ||||||
| func (t *ToolBox) toOpenAI() []openai.Tool { |  | ||||||
| 	var res []openai.Tool |  | ||||||
|  |  | ||||||
| 	for _, f := range t.funcs { |  | ||||||
| 		res = append(res, openai.Tool{ |  | ||||||
| 			Type:     "function", |  | ||||||
| 			Function: f.toOpenAIFunction(), |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return res | 	return res | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *ToolBox) ToToolChoice() any { | func (t ToolBox) Functions() []Function { | ||||||
| 	if len(t.funcs) == 0 { | 	var res []Function | ||||||
|  |  | ||||||
|  | 	for _, f := range t.functions { | ||||||
|  | 		res = append(res, f) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) WithFunction(f Function) ToolBox { | ||||||
|  | 	t.functions[f.Name] = f | ||||||
|  |  | ||||||
|  | 	return t | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) WithFunctions(fns ...Function) ToolBox { | ||||||
|  | 	for _, f := range fns { | ||||||
|  | 		t.functions[f.Name] = f | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return t | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox { | ||||||
|  | 	for k, v := range t.functions { | ||||||
|  | 		t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return t | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) ForEachFunction(fn func(f Function)) { | ||||||
|  | 	for _, f := range t.functions { | ||||||
|  | 		fn(f) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) WithFunctionRemoved(name string) ToolBox { | ||||||
|  | 	delete(t.functions, name) | ||||||
|  | 	return t | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) WithRequireTool(val bool) ToolBox { | ||||||
|  | 	t.dontRequireTool = !val | ||||||
|  | 	return t | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) RequiresTool() bool { | ||||||
|  | 	return !t.dontRequireTool && len(t.functions) > 0 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) ToToolChoice() any { | ||||||
|  | 	if len(t.functions) == 0 { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -87,8 +90,8 @@ var ( | |||||||
| 	ErrFunctionNotFound = errors.New("function not found") | 	ErrFunctionNotFound = errors.New("function not found") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { | func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { | ||||||
| 	f, ok := t.names[functionName] | 	f, ok := t.functions[functionName] | ||||||
|  |  | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) | 		return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) | ||||||
| @@ -97,14 +100,29 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri | |||||||
| 	return f.Execute(ctx, params) | 	return f.Execute(ctx, params) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { | func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { | ||||||
| 	return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) | 	return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string { | ||||||
|  | 	val := ctx.Value("syntheticParameters") | ||||||
|  |  | ||||||
|  | 	if val == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	syntheticParameters, ok := val.(map[string]string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return syntheticParameters | ||||||
|  | } | ||||||
|  |  | ||||||
| // ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. | // ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. | ||||||
| // OnNewFunction is called when a new function is created | // OnNewFunction is called when a new function is created | ||||||
| // OnFunctionFinished is called when a function is finished | // OnFunctionFinished is called when a function is finished | ||||||
| func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { | func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { | ||||||
| 	var res []ToolCallResponse | 	var res []ToolCallResponse | ||||||
|  |  | ||||||
| 	for _, call := range toolCalls { | 	for _, call := range toolCalls { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user