Refactor toolbox and function handling to support synthetic fields and improve type definitions

This commit is contained in:
Steve Dudenhoeffer 2025-04-12 02:20:40 -04:00
parent 2ae583e9f3
commit 3093b988f8
13 changed files with 288 additions and 160 deletions

View File

@ -134,15 +134,16 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
}
}
if req.Toolbox != nil {
for _, tool := range req.Toolbox.funcs {
/*
for _, tool := range req.Toolbox.functions {
res.Tools = append(res.Tools, anth.ToolDefinition{
Name: tool.Name,
Description: tool.Description,
InputSchema: tool.Parameters,
InputSchema: tool.Parameters.OpenAIParameters(),
})
}
}
*/
res.Messages = msgs

View File

@ -7,9 +7,10 @@ import (
type Context struct {
context.Context
request Request
response *ResponseChoice
toolcall *ToolCall
request Request
response *ResponseChoice
toolcall *ToolCall
syntheticFields map[string]string
}
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
@ -55,20 +56,32 @@ func (c *Context) ToolCall() *ToolCall {
return c.toolcall
}
func (c *Context) SyntheticFields() map[string]string {
if c.syntheticFields == nil {
c.syntheticFields = map[string]string{}
}
return c.syntheticFields
}
func (c *Context) WithContext(ctx context.Context) *Context {
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall}
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithRequest(request Request) *Context {
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall}
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithResponse(response *ResponseChoice) *Context {
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall}
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall}
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields}
}
func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context {
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields}
}
func (c *Context) Deadline() (deadline time.Time, ok bool) {
@ -84,8 +97,19 @@ func (c *Context) Err() error {
}
func (c *Context) Value(key any) any {
if key == "request" {
switch key {
case "request":
return c.request
case "response":
return c.response
case "toolcall":
return c.toolcall
case "syntheticFields":
return c.syntheticFields
}
return c.Context.Value(key)
}

View File

@ -4,11 +4,10 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"reflect"
"time"
"github.com/sashabaranov/go-openai"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)
@ -29,16 +28,63 @@ type Function struct {
paramType reflect.Type
}
func (f *Function) Execute(ctx *Context, input string) (any, error) {
func (f Function) WithSyntheticField(name string, description string) Function {
if obj, o := f.Parameters.(schema.Object); o {
f.Parameters = obj.WithSyntheticField(name, description)
}
return f
}
func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function {
if obj, o := f.Parameters.(schema.Object); o {
for k, v := range fieldsAndDescriptions {
obj = obj.WithSyntheticField(k, v)
}
f.Parameters = obj
}
return f
}
func (f Function) Execute(ctx *Context, input string) (any, error) {
if !f.fn.IsValid() {
return "", fmt.Errorf("function %s is not implemented", f.Name)
}
slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType)
// first, we need to parse the input into the struct
p := reflect.New(f.paramType)
fmt.Println("Function.Execute", f.Name, "input:", input)
//m := map[string]any{}
err := json.Unmarshal([]byte(input), p.Interface())
var vals map[string]any
err := json.Unmarshal([]byte(input), &vals)
var syntheticFields map[string]string
// first eat up any synthetic fields
if obj, o := f.Parameters.(schema.Object); o {
for k := range obj.SyntheticFields() {
key := schema.SyntheticFieldPrefix + k
if val, ok := vals[key]; ok {
if syntheticFields == nil {
syntheticFields = map[string]string{}
}
syntheticFields[k] = fmt.Sprint(val)
delete(vals, key)
}
}
}
// now for any remaining fields, re-marshal them into json and then unmarshal into the struct
b, err := json.Marshal(vals)
if err != nil {
return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input)
}
// now we can unmarshal the input into the struct
err = json.Unmarshal(b, p.Interface())
if err != nil {
return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
}
@ -67,15 +113,6 @@ func (f *Function) Execute(ctx *Context, input string) (any, error) {
return exec(ctx)
}
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
return &openai.FunctionDefinition{
Name: f.Name,
Description: f.Description,
Strict: f.Strict,
Parameters: f.Parameters,
}
}
type FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`

View File

@ -13,7 +13,7 @@ import (
// The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) *Function {
func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function {
var o T
res := Function{
@ -31,5 +31,5 @@ func NewFunction[T any](name string, description string, fn func(*Context, T) (a
panic("function parameter must be a struct")
}
return &res
return res
}

9
go.mod
View File

@ -5,8 +5,7 @@ go 1.23.1
require (
github.com/google/generative-ai-go v0.19.0
github.com/liushuangls/go-anthropic/v2 v2.15.0
github.com/openai/openai-go v0.1.0-beta.6
github.com/sashabaranov/go-openai v1.38.1
github.com/openai/openai-go v0.1.0-beta.9
google.golang.org/api v0.228.0
)
@ -35,14 +34,14 @@ require (
go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/net v0.39.0 // indirect
golang.org/x/oauth2 v0.29.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect
google.golang.org/grpc v1.71.1 // indirect
google.golang.org/protobuf v1.36.6 // indirect
)

10
go.sum
View File

@ -37,10 +37,10 @@ github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9
github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4=
github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ=
github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo=
github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8=
github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@ -73,6 +73,8 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
@ -87,8 +89,12 @@ google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 h1:Qbb5RVn5xzI4naMJSpJ7lhvmos6UwZkbekd5Uz7rt9E=
google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:6T35kB3IPpdw7Wul09by0G/JuOuIFkXV6OOvt8IZeT8=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 h1:0K7wTWyzxZ7J+L47+LbFogJW1nn/gnnMCN0vGXNYtTI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI=
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=

View File

@ -23,25 +23,22 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
res := *model
if in.Toolbox != nil {
for _, tool := range in.Toolbox.funcs {
res.Tools = append(res.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: tool.Parameters.GoogleParameters(),
},
for _, tool := range in.Toolbox.functions {
res.Tools = append(res.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: tool.Parameters.GoogleParameters(),
},
})
}
},
})
}
if !in.Toolbox.dontRequireTool {
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingAny,
}}
}
if !in.Toolbox.RequiresTool() {
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingAny,
}}
}
cs := res.StartChat()

View File

@ -31,23 +31,21 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
}
if request.Toolbox != nil {
for _, tool := range request.Toolbox.funcs {
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
Type: "function",
Function: shared.FunctionDefinitionParam{
Name: tool.Name,
Description: openai.String(tool.Description),
Strict: openai.Bool(tool.Strict),
Parameters: tool.Parameters.OpenAIParameters(),
},
})
}
for _, tool := range request.Toolbox.functions {
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
Type: "function",
Function: shared.FunctionDefinitionParam{
Name: tool.Name,
Description: openai.String(tool.Description),
Strict: openai.Bool(tool.Strict),
Parameters: tool.Parameters.OpenAIParameters(),
},
})
}
if !request.Toolbox.dontRequireTool {
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: openai.String("required"),
}
if request.Toolbox.RequiresTool() {
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: openai.String("required"),
}
}

View File

@ -15,7 +15,7 @@ type Input interface {
type Request struct {
Conversation []Input
Messages []Message
Toolbox *ToolBox
Toolbox ToolBox
Temperature *float64
}

View File

@ -25,27 +25,27 @@ func getFromType(t reflect.Type, b basic) Type {
switch t.Kind() {
case reflect.String:
b.DataType = String
b.DataType = TypeString
b.typeName = "string"
return b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b.DataType = Integer
b.DataType = TypeInteger
b.typeName = "integer"
return b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b.DataType = Integer
b.DataType = TypeInteger
b.typeName = "integer"
return b
case reflect.Float32, reflect.Float64:
b.DataType = Number
b.DataType = TypeNumber
b.typeName = "number"
return b
case reflect.Bool:
b.DataType = Boolean
b.DataType = TypeBoolean
b.typeName = "boolean"
return b
@ -92,7 +92,7 @@ func getField(f reflect.StructField, index int) Type {
}
}
b.DataType = String
b.DataType = TypeString
b.typeName = "string"
return enum{
basic: b,
@ -104,15 +104,26 @@ func getField(f reflect.StructField, index int) Type {
return getFromType(t, b)
}
func getObject(t reflect.Type) object {
func getObject(t reflect.Type) Object {
fields := make(map[string]Type, t.NumField())
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fields[field.Name] = getField(field, i)
if field.Anonymous {
// if the field is anonymous, we need to get the fields of the anonymous struct
// and add them to the object
anon := getObject(field.Type)
for k, v := range anon.fields {
fields[k] = v
}
continue
} else {
fields[field.Name] = getField(field, i)
}
}
return object{
basic: basic{DataType: Object, typeName: "object"},
return Object{
basic: basic{DataType: TypeObject, typeName: "object"},
fields: fields,
}
}
@ -120,7 +131,7 @@ func getObject(t reflect.Type) object {
func getArray(t reflect.Type) array {
res := array{
basic: basic{
DataType: Array,
DataType: TypeArray,
typeName: "array",
},
}

View File

@ -15,12 +15,12 @@ var _ Type = basic{}
type DataType string
const (
String DataType = "string"
Integer DataType = "integer"
Number DataType = "number"
Boolean DataType = "boolean"
Object DataType = "object"
Array DataType = "array"
TypeString DataType = "string"
TypeInteger DataType = "integer"
TypeNumber DataType = "number"
TypeBoolean DataType = "boolean"
TypeObject DataType = "object"
TypeArray DataType = "array"
)
type basic struct {
@ -49,17 +49,17 @@ func (b basic) GoogleParameters() *genai.Schema {
var t = genai.TypeUnspecified
switch b.DataType {
case String:
case TypeString:
t = genai.TypeString
case Integer:
case TypeInteger:
t = genai.TypeInteger
case Number:
case TypeNumber:
t = genai.TypeNumber
case Boolean:
case TypeBoolean:
t = genai.TypeBoolean
case Object:
case TypeObject:
t = genai.TypeObject
case Array:
case TypeArray:
t = genai.TypeArray
default:
t = genai.TypeUnspecified
@ -82,12 +82,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
switch b.DataType {
case String:
case TypeString:
var val = v.String()
return reflect.ValueOf(val), nil
case Integer:
case TypeInteger:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(int(0))), nil
} else if v.Kind() != reflect.Int {
@ -96,7 +96,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil
}
case Number:
case TypeNumber:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(float64(0))), nil
} else if v.Kind() != reflect.Float64 {
@ -105,7 +105,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
return v, nil
}
case Boolean:
case TypeBoolean:
if v.Kind() == reflect.Bool {
return v, nil
} else if v.Kind() == reflect.String {

View File

@ -8,15 +8,44 @@ import (
"github.com/openai/openai-go"
)
type object struct {
const (
// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
// collisions with the fields in the struct.
SyntheticFieldPrefix = "__"
)
type Object struct {
basic
ref reflect.Type
fields map[string]Type
// syntheticFields are fields that are not in the struct but are generated by a system.
synetheticFields map[string]Type
}
func (o object) OpenAIParameters() openai.FunctionParameters {
func (o Object) WithSyntheticField(name string, description string) Object {
if o.synetheticFields == nil {
o.synetheticFields = map[string]Type{}
}
o.synetheticFields[name] = basic{
DataType: TypeString,
typeName: "string",
index: -1,
required: false,
description: description,
}
return o
}
func (o Object) SyntheticFields() map[string]Type {
return o.synetheticFields
}
func (o Object) OpenAIParameters() openai.FunctionParameters {
var properties = map[string]openai.FunctionParameters{}
var required []string
for k, v := range o.fields {
@ -26,6 +55,13 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
}
}
for k, v := range o.synetheticFields {
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
if v.Required() {
required = append(required, SyntheticFieldPrefix+k)
}
}
var res = openai.FunctionParameters{
"type": "object",
"description": o.Description(),
@ -39,7 +75,7 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
return res
}
func (o object) GoogleParameters() *genai.Schema {
func (o Object) GoogleParameters() *genai.Schema {
var properties = map[string]*genai.Schema{}
var required []string
for k, v := range o.fields {
@ -62,7 +98,8 @@ func (o object) GoogleParameters() *genai.Schema {
return res
}
func (o object) FromAny(val any) (reflect.Value, error) {
// FromAny converts the value from any to the correct type, returning the value, and an error if any
func (o Object) FromAny(val any) (reflect.Value, error) {
// if the value is nil, we can't do anything
if val == nil {
return reflect.Value{}, nil
@ -99,7 +136,7 @@ func (o object) FromAny(val any) (reflect.Value, error) {
return obj, nil
}
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
if !o.required {
val = val.Addr()

View File

@ -4,79 +4,82 @@ import (
"context"
"errors"
"fmt"
"github.com/sashabaranov/go-openai"
)
// ToolBox is a collection of tools that OpenAI can use to execute functions.
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
// the correct parameters.
type ToolBox struct {
funcs []Function
names map[string]Function
functions map[string]Function
dontRequireTool bool
}
func NewToolBox(fns ...*Function) *ToolBox {
func NewToolBox(fns ...Function) ToolBox {
res := ToolBox{
funcs: []Function{},
names: map[string]Function{},
functions: map[string]Function{},
}
for _, f := range fns {
o := *f
res.names[o.Name] = o
res.funcs = append(res.funcs, o)
}
return &res
}
func (t *ToolBox) WithFunction(f Function) *ToolBox {
t2 := *t
t2.names[f.Name] = f
t2.funcs = append(t2.funcs, f)
return &t2
}
func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox {
t2 := *t
delete(t2.names, name)
for i, f := range t2.funcs {
if f.Name == name {
t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...)
break
}
}
return &t2
}
func (t *ToolBox) WithRequireTool(val bool) *ToolBox {
t2 := *t
t2.dontRequireTool = !val
return &t2
}
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
func (t *ToolBox) toOpenAI() []openai.Tool {
var res []openai.Tool
for _, f := range t.funcs {
res = append(res, openai.Tool{
Type: "function",
Function: f.toOpenAIFunction(),
})
res.functions[f.Name] = f
}
return res
}
func (t *ToolBox) ToToolChoice() any {
if len(t.funcs) == 0 {
func (t ToolBox) Functions() []Function {
var res []Function
for _, f := range t.functions {
res = append(res, f)
}
return res
}
func (t ToolBox) WithFunction(f Function) ToolBox {
t.functions[f.Name] = f
return t
}
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
for _, f := range fns {
t.functions[f.Name] = f
}
return t
}
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
for k, v := range t.functions {
t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
}
return t
}
func (t ToolBox) ForEachFunction(fn func(f Function)) {
for _, f := range t.functions {
fn(f)
}
}
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
delete(t.functions, name)
return t
}
func (t ToolBox) WithRequireTool(val bool) ToolBox {
t.dontRequireTool = !val
return t
}
func (t ToolBox) RequiresTool() bool {
return !t.dontRequireTool && len(t.functions) > 0
}
func (t ToolBox) ToToolChoice() any {
if len(t.functions) == 0 {
return nil
}
@ -87,8 +90,8 @@ var (
ErrFunctionNotFound = errors.New("function not found")
)
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
f, ok := t.names[functionName]
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
f, ok := t.functions[functionName]
if !ok {
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
@ -97,14 +100,29 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri
return f.Execute(ctx, params)
}
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
}
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
val := ctx.Value("syntheticParameters")
if val == nil {
return nil
}
syntheticParameters, ok := val.(map[string]string)
if !ok {
return nil
}
return syntheticParameters
}
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
// OnNewFunction is called when a new function is created
// OnFunctionFinished is called when a function is finished
func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
var res []ToolCallResponse
for _, call := range toolCalls {