Refactor toolbox and function handling to support synthetic fields and improve type definitions
This commit is contained in:
parent
2ae583e9f3
commit
3093b988f8
@ -134,15 +134,16 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
||||
}
|
||||
}
|
||||
|
||||
if req.Toolbox != nil {
|
||||
for _, tool := range req.Toolbox.funcs {
|
||||
/*
|
||||
for _, tool := range req.Toolbox.functions {
|
||||
res.Tools = append(res.Tools, anth.ToolDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: tool.Parameters,
|
||||
InputSchema: tool.Parameters.OpenAIParameters(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
res.Messages = msgs
|
||||
|
||||
|
40
context.go
40
context.go
@ -7,9 +7,10 @@ import (
|
||||
|
||||
type Context struct {
|
||||
context.Context
|
||||
request Request
|
||||
response *ResponseChoice
|
||||
toolcall *ToolCall
|
||||
request Request
|
||||
response *ResponseChoice
|
||||
toolcall *ToolCall
|
||||
syntheticFields map[string]string
|
||||
}
|
||||
|
||||
func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
|
||||
@ -55,20 +56,32 @@ func (c *Context) ToolCall() *ToolCall {
|
||||
return c.toolcall
|
||||
}
|
||||
|
||||
func (c *Context) SyntheticFields() map[string]string {
|
||||
if c.syntheticFields == nil {
|
||||
c.syntheticFields = map[string]string{}
|
||||
}
|
||||
|
||||
return c.syntheticFields
|
||||
}
|
||||
|
||||
func (c *Context) WithContext(ctx context.Context) *Context {
|
||||
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall}
|
||||
return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
|
||||
}
|
||||
|
||||
func (c *Context) WithRequest(request Request) *Context {
|
||||
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall}
|
||||
return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
|
||||
}
|
||||
|
||||
func (c *Context) WithResponse(response *ResponseChoice) *Context {
|
||||
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall}
|
||||
return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
|
||||
}
|
||||
|
||||
func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
|
||||
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall}
|
||||
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields}
|
||||
}
|
||||
|
||||
func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context {
|
||||
return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields}
|
||||
}
|
||||
|
||||
func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
||||
@ -84,8 +97,19 @@ func (c *Context) Err() error {
|
||||
}
|
||||
|
||||
func (c *Context) Value(key any) any {
|
||||
if key == "request" {
|
||||
switch key {
|
||||
case "request":
|
||||
return c.request
|
||||
|
||||
case "response":
|
||||
return c.response
|
||||
|
||||
case "toolcall":
|
||||
return c.toolcall
|
||||
|
||||
case "syntheticFields":
|
||||
return c.syntheticFields
|
||||
|
||||
}
|
||||
return c.Context.Value(key)
|
||||
}
|
||||
|
65
function.go
65
function.go
@ -4,11 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||
)
|
||||
|
||||
@ -29,16 +28,63 @@ type Function struct {
|
||||
paramType reflect.Type
|
||||
}
|
||||
|
||||
func (f *Function) Execute(ctx *Context, input string) (any, error) {
|
||||
func (f Function) WithSyntheticField(name string, description string) Function {
|
||||
if obj, o := f.Parameters.(schema.Object); o {
|
||||
f.Parameters = obj.WithSyntheticField(name, description)
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function {
|
||||
if obj, o := f.Parameters.(schema.Object); o {
|
||||
for k, v := range fieldsAndDescriptions {
|
||||
obj = obj.WithSyntheticField(k, v)
|
||||
}
|
||||
f.Parameters = obj
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (f Function) Execute(ctx *Context, input string) (any, error) {
|
||||
if !f.fn.IsValid() {
|
||||
return "", fmt.Errorf("function %s is not implemented", f.Name)
|
||||
}
|
||||
|
||||
slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType)
|
||||
// first, we need to parse the input into the struct
|
||||
p := reflect.New(f.paramType)
|
||||
fmt.Println("Function.Execute", f.Name, "input:", input)
|
||||
//m := map[string]any{}
|
||||
err := json.Unmarshal([]byte(input), p.Interface())
|
||||
|
||||
var vals map[string]any
|
||||
err := json.Unmarshal([]byte(input), &vals)
|
||||
|
||||
var syntheticFields map[string]string
|
||||
|
||||
// first eat up any synthetic fields
|
||||
if obj, o := f.Parameters.(schema.Object); o {
|
||||
for k := range obj.SyntheticFields() {
|
||||
key := schema.SyntheticFieldPrefix + k
|
||||
if val, ok := vals[key]; ok {
|
||||
if syntheticFields == nil {
|
||||
syntheticFields = map[string]string{}
|
||||
}
|
||||
|
||||
syntheticFields[k] = fmt.Sprint(val)
|
||||
delete(vals, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// now for any remaining fields, re-marshal them into json and then unmarshal into the struct
|
||||
b, err := json.Marshal(vals)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input)
|
||||
}
|
||||
|
||||
// now we can unmarshal the input into the struct
|
||||
err = json.Unmarshal(b, p.Interface())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
|
||||
}
|
||||
@ -67,15 +113,6 @@ func (f *Function) Execute(ctx *Context, input string) (any, error) {
|
||||
return exec(ctx)
|
||||
}
|
||||
|
||||
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
|
||||
return &openai.FunctionDefinition{
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
Strict: f.Strict,
|
||||
Parameters: f.Parameters,
|
||||
}
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
|
@ -13,7 +13,7 @@ import (
|
||||
// The struct parameters can have the following tags:
|
||||
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
|
||||
|
||||
func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) *Function {
|
||||
func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function {
|
||||
var o T
|
||||
|
||||
res := Function{
|
||||
@ -31,5 +31,5 @@ func NewFunction[T any](name string, description string, fn func(*Context, T) (a
|
||||
panic("function parameter must be a struct")
|
||||
}
|
||||
|
||||
return &res
|
||||
return res
|
||||
}
|
||||
|
9
go.mod
9
go.mod
@ -5,8 +5,7 @@ go 1.23.1
|
||||
require (
|
||||
github.com/google/generative-ai-go v0.19.0
|
||||
github.com/liushuangls/go-anthropic/v2 v2.15.0
|
||||
github.com/openai/openai-go v0.1.0-beta.6
|
||||
github.com/sashabaranov/go-openai v1.38.1
|
||||
github.com/openai/openai-go v0.1.0-beta.9
|
||||
google.golang.org/api v0.228.0
|
||||
)
|
||||
|
||||
@ -35,14 +34,14 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
golang.org/x/crypto v0.37.0 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/net v0.39.0 // indirect
|
||||
golang.org/x/oauth2 v0.29.0 // indirect
|
||||
golang.org/x/sync v0.13.0 // indirect
|
||||
golang.org/x/sys v0.32.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
golang.org/x/time v0.11.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect
|
||||
google.golang.org/grpc v1.71.1 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -37,10 +37,10 @@ github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9
|
||||
github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4=
|
||||
github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ=
|
||||
github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo=
|
||||
github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8=
|
||||
github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@ -73,6 +73,8 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
|
||||
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
@ -87,8 +89,12 @@ google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
|
||||
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0 h1:Qbb5RVn5xzI4naMJSpJ7lhvmos6UwZkbekd5Uz7rt9E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:6T35kB3IPpdw7Wul09by0G/JuOuIFkXV6OOvt8IZeT8=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0 h1:0K7wTWyzxZ7J+L47+LbFogJW1nn/gnnMCN0vGXNYtTI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250404141209-ee84b53bf3d0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI=
|
||||
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
|
31
google.go
31
google.go
@ -23,25 +23,22 @@ func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
|
||||
res := *model
|
||||
|
||||
if in.Toolbox != nil {
|
||||
for _, tool := range in.Toolbox.funcs {
|
||||
|
||||
res.Tools = append(res.Tools, &genai.Tool{
|
||||
FunctionDeclarations: []*genai.FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: tool.Parameters.GoogleParameters(),
|
||||
},
|
||||
for _, tool := range in.Toolbox.functions {
|
||||
res.Tools = append(res.Tools, &genai.Tool{
|
||||
FunctionDeclarations: []*genai.FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: tool.Parameters.GoogleParameters(),
|
||||
},
|
||||
})
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if !in.Toolbox.dontRequireTool {
|
||||
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
|
||||
Mode: genai.FunctionCallingAny,
|
||||
}}
|
||||
}
|
||||
if !in.Toolbox.RequiresTool() {
|
||||
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
|
||||
Mode: genai.FunctionCallingAny,
|
||||
}}
|
||||
}
|
||||
|
||||
cs := res.StartChat()
|
||||
|
30
openai.go
30
openai.go
@ -31,23 +31,21 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple
|
||||
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
|
||||
}
|
||||
|
||||
if request.Toolbox != nil {
|
||||
for _, tool := range request.Toolbox.funcs {
|
||||
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
|
||||
Type: "function",
|
||||
Function: shared.FunctionDefinitionParam{
|
||||
Name: tool.Name,
|
||||
Description: openai.String(tool.Description),
|
||||
Strict: openai.Bool(tool.Strict),
|
||||
Parameters: tool.Parameters.OpenAIParameters(),
|
||||
},
|
||||
})
|
||||
}
|
||||
for _, tool := range request.Toolbox.functions {
|
||||
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
|
||||
Type: "function",
|
||||
Function: shared.FunctionDefinitionParam{
|
||||
Name: tool.Name,
|
||||
Description: openai.String(tool.Description),
|
||||
Strict: openai.Bool(tool.Strict),
|
||||
Parameters: tool.Parameters.OpenAIParameters(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if !request.Toolbox.dontRequireTool {
|
||||
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
|
||||
OfAuto: openai.String("required"),
|
||||
}
|
||||
if request.Toolbox.RequiresTool() {
|
||||
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
|
||||
OfAuto: openai.String("required"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ type Input interface {
|
||||
type Request struct {
|
||||
Conversation []Input
|
||||
Messages []Message
|
||||
Toolbox *ToolBox
|
||||
Toolbox ToolBox
|
||||
Temperature *float64
|
||||
}
|
||||
|
||||
|
@ -25,27 +25,27 @@ func getFromType(t reflect.Type, b basic) Type {
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
b.DataType = String
|
||||
b.DataType = TypeString
|
||||
b.typeName = "string"
|
||||
return b
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
b.DataType = Integer
|
||||
b.DataType = TypeInteger
|
||||
b.typeName = "integer"
|
||||
return b
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
b.DataType = Integer
|
||||
b.DataType = TypeInteger
|
||||
b.typeName = "integer"
|
||||
return b
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
b.DataType = Number
|
||||
b.DataType = TypeNumber
|
||||
b.typeName = "number"
|
||||
return b
|
||||
|
||||
case reflect.Bool:
|
||||
b.DataType = Boolean
|
||||
b.DataType = TypeBoolean
|
||||
b.typeName = "boolean"
|
||||
return b
|
||||
|
||||
@ -92,7 +92,7 @@ func getField(f reflect.StructField, index int) Type {
|
||||
}
|
||||
}
|
||||
|
||||
b.DataType = String
|
||||
b.DataType = TypeString
|
||||
b.typeName = "string"
|
||||
return enum{
|
||||
basic: b,
|
||||
@ -104,15 +104,26 @@ func getField(f reflect.StructField, index int) Type {
|
||||
return getFromType(t, b)
|
||||
}
|
||||
|
||||
func getObject(t reflect.Type) object {
|
||||
func getObject(t reflect.Type) Object {
|
||||
fields := make(map[string]Type, t.NumField())
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fields[field.Name] = getField(field, i)
|
||||
|
||||
if field.Anonymous {
|
||||
// if the field is anonymous, we need to get the fields of the anonymous struct
|
||||
// and add them to the object
|
||||
anon := getObject(field.Type)
|
||||
for k, v := range anon.fields {
|
||||
fields[k] = v
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
fields[field.Name] = getField(field, i)
|
||||
}
|
||||
}
|
||||
|
||||
return object{
|
||||
basic: basic{DataType: Object, typeName: "object"},
|
||||
return Object{
|
||||
basic: basic{DataType: TypeObject, typeName: "object"},
|
||||
fields: fields,
|
||||
}
|
||||
}
|
||||
@ -120,7 +131,7 @@ func getObject(t reflect.Type) object {
|
||||
func getArray(t reflect.Type) array {
|
||||
res := array{
|
||||
basic: basic{
|
||||
DataType: Array,
|
||||
DataType: TypeArray,
|
||||
typeName: "array",
|
||||
},
|
||||
}
|
||||
|
@ -15,12 +15,12 @@ var _ Type = basic{}
|
||||
type DataType string
|
||||
|
||||
const (
|
||||
String DataType = "string"
|
||||
Integer DataType = "integer"
|
||||
Number DataType = "number"
|
||||
Boolean DataType = "boolean"
|
||||
Object DataType = "object"
|
||||
Array DataType = "array"
|
||||
TypeString DataType = "string"
|
||||
TypeInteger DataType = "integer"
|
||||
TypeNumber DataType = "number"
|
||||
TypeBoolean DataType = "boolean"
|
||||
TypeObject DataType = "object"
|
||||
TypeArray DataType = "array"
|
||||
)
|
||||
|
||||
type basic struct {
|
||||
@ -49,17 +49,17 @@ func (b basic) GoogleParameters() *genai.Schema {
|
||||
var t = genai.TypeUnspecified
|
||||
|
||||
switch b.DataType {
|
||||
case String:
|
||||
case TypeString:
|
||||
t = genai.TypeString
|
||||
case Integer:
|
||||
case TypeInteger:
|
||||
t = genai.TypeInteger
|
||||
case Number:
|
||||
case TypeNumber:
|
||||
t = genai.TypeNumber
|
||||
case Boolean:
|
||||
case TypeBoolean:
|
||||
t = genai.TypeBoolean
|
||||
case Object:
|
||||
case TypeObject:
|
||||
t = genai.TypeObject
|
||||
case Array:
|
||||
case TypeArray:
|
||||
t = genai.TypeArray
|
||||
default:
|
||||
t = genai.TypeUnspecified
|
||||
@ -82,12 +82,12 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
|
||||
switch b.DataType {
|
||||
case String:
|
||||
case TypeString:
|
||||
var val = v.String()
|
||||
|
||||
return reflect.ValueOf(val), nil
|
||||
|
||||
case Integer:
|
||||
case TypeInteger:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(int(0))), nil
|
||||
} else if v.Kind() != reflect.Int {
|
||||
@ -96,7 +96,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case Number:
|
||||
case TypeNumber:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(float64(0))), nil
|
||||
} else if v.Kind() != reflect.Float64 {
|
||||
@ -105,7 +105,7 @@ func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case Boolean:
|
||||
case TypeBoolean:
|
||||
if v.Kind() == reflect.Bool {
|
||||
return v, nil
|
||||
} else if v.Kind() == reflect.String {
|
||||
|
@ -8,15 +8,44 @@ import (
|
||||
"github.com/openai/openai-go"
|
||||
)
|
||||
|
||||
type object struct {
|
||||
const (
|
||||
// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
|
||||
// collisions with the fields in the struct.
|
||||
SyntheticFieldPrefix = "__"
|
||||
)
|
||||
|
||||
type Object struct {
|
||||
basic
|
||||
|
||||
ref reflect.Type
|
||||
|
||||
fields map[string]Type
|
||||
|
||||
// syntheticFields are fields that are not in the struct but are generated by a system.
|
||||
synetheticFields map[string]Type
|
||||
}
|
||||
|
||||
func (o object) OpenAIParameters() openai.FunctionParameters {
|
||||
func (o Object) WithSyntheticField(name string, description string) Object {
|
||||
if o.synetheticFields == nil {
|
||||
o.synetheticFields = map[string]Type{}
|
||||
}
|
||||
|
||||
o.synetheticFields[name] = basic{
|
||||
DataType: TypeString,
|
||||
typeName: "string",
|
||||
index: -1,
|
||||
required: false,
|
||||
description: description,
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
func (o Object) SyntheticFields() map[string]Type {
|
||||
return o.synetheticFields
|
||||
}
|
||||
|
||||
func (o Object) OpenAIParameters() openai.FunctionParameters {
|
||||
var properties = map[string]openai.FunctionParameters{}
|
||||
var required []string
|
||||
for k, v := range o.fields {
|
||||
@ -26,6 +55,13 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range o.synetheticFields {
|
||||
properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
|
||||
if v.Required() {
|
||||
required = append(required, SyntheticFieldPrefix+k)
|
||||
}
|
||||
}
|
||||
|
||||
var res = openai.FunctionParameters{
|
||||
"type": "object",
|
||||
"description": o.Description(),
|
||||
@ -39,7 +75,7 @@ func (o object) OpenAIParameters() openai.FunctionParameters {
|
||||
return res
|
||||
}
|
||||
|
||||
func (o object) GoogleParameters() *genai.Schema {
|
||||
func (o Object) GoogleParameters() *genai.Schema {
|
||||
var properties = map[string]*genai.Schema{}
|
||||
var required []string
|
||||
for k, v := range o.fields {
|
||||
@ -62,7 +98,8 @@ func (o object) GoogleParameters() *genai.Schema {
|
||||
return res
|
||||
}
|
||||
|
||||
func (o object) FromAny(val any) (reflect.Value, error) {
|
||||
// FromAny converts the value from any to the correct type, returning the value, and an error if any
|
||||
func (o Object) FromAny(val any) (reflect.Value, error) {
|
||||
// if the value is nil, we can't do anything
|
||||
if val == nil {
|
||||
return reflect.Value{}, nil
|
||||
@ -99,7 +136,7 @@ func (o object) FromAny(val any) (reflect.Value, error) {
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
|
||||
if !o.required {
|
||||
val = val.Addr()
|
||||
|
136
toolbox.go
136
toolbox.go
@ -4,79 +4,82 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
||||
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
||||
// the correct parameters.
|
||||
type ToolBox struct {
|
||||
funcs []Function
|
||||
names map[string]Function
|
||||
functions map[string]Function
|
||||
dontRequireTool bool
|
||||
}
|
||||
|
||||
func NewToolBox(fns ...*Function) *ToolBox {
|
||||
func NewToolBox(fns ...Function) ToolBox {
|
||||
res := ToolBox{
|
||||
funcs: []Function{},
|
||||
names: map[string]Function{},
|
||||
functions: map[string]Function{},
|
||||
}
|
||||
|
||||
for _, f := range fns {
|
||||
o := *f
|
||||
res.names[o.Name] = o
|
||||
res.funcs = append(res.funcs, o)
|
||||
}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
||||
t2 := *t
|
||||
t2.names[f.Name] = f
|
||||
t2.funcs = append(t2.funcs, f)
|
||||
|
||||
return &t2
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox {
|
||||
t2 := *t
|
||||
|
||||
delete(t2.names, name)
|
||||
|
||||
for i, f := range t2.funcs {
|
||||
if f.Name == name {
|
||||
t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &t2
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithRequireTool(val bool) *ToolBox {
|
||||
t2 := *t
|
||||
t2.dontRequireTool = !val
|
||||
return &t2
|
||||
}
|
||||
|
||||
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
||||
func (t *ToolBox) toOpenAI() []openai.Tool {
|
||||
var res []openai.Tool
|
||||
|
||||
for _, f := range t.funcs {
|
||||
res = append(res, openai.Tool{
|
||||
Type: "function",
|
||||
Function: f.toOpenAIFunction(),
|
||||
})
|
||||
res.functions[f.Name] = f
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t *ToolBox) ToToolChoice() any {
|
||||
if len(t.funcs) == 0 {
|
||||
func (t ToolBox) Functions() []Function {
|
||||
var res []Function
|
||||
|
||||
for _, f := range t.functions {
|
||||
res = append(res, f)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t ToolBox) WithFunction(f Function) ToolBox {
|
||||
t.functions[f.Name] = f
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
|
||||
for _, f := range fns {
|
||||
t.functions[f.Name] = f
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
|
||||
for k, v := range t.functions {
|
||||
t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) ForEachFunction(fn func(f Function)) {
|
||||
for _, f := range t.functions {
|
||||
fn(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
|
||||
delete(t.functions, name)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) WithRequireTool(val bool) ToolBox {
|
||||
t.dontRequireTool = !val
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) RequiresTool() bool {
|
||||
return !t.dontRequireTool && len(t.functions) > 0
|
||||
}
|
||||
|
||||
func (t ToolBox) ToToolChoice() any {
|
||||
if len(t.functions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -87,8 +90,8 @@ var (
|
||||
ErrFunctionNotFound = errors.New("function not found")
|
||||
)
|
||||
|
||||
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
||||
f, ok := t.names[functionName]
|
||||
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
||||
f, ok := t.functions[functionName]
|
||||
|
||||
if !ok {
|
||||
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
||||
@ -97,14 +100,29 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri
|
||||
return f.Execute(ctx, params)
|
||||
}
|
||||
|
||||
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
|
||||
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
|
||||
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
||||
}
|
||||
|
||||
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
|
||||
val := ctx.Value("syntheticParameters")
|
||||
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
syntheticParameters, ok := val.(map[string]string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return syntheticParameters
|
||||
}
|
||||
|
||||
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
|
||||
// OnNewFunction is called when a new function is created
|
||||
// OnFunctionFinished is called when a function is finished
|
||||
func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
|
||||
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
|
||||
var res []ToolCallResponse
|
||||
|
||||
for _, call := range toolCalls {
|
||||
|
Loading…
x
Reference in New Issue
Block a user