From 37939088edc06fd6526c021367234c9a630c381c Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Fri, 8 Nov 2024 20:53:12 -0500 Subject: [PATCH] initial commit of untested function stuff --- error.go | 21 ++++++++ function.go | 86 +++++++++++++++++++++++++++++-- functions.go | 35 +++++++++++++ go.mod | 7 +-- go.sum | 13 ++++- google.go | 2 +- llm.go | 5 +- openai.go | 16 +++--- schema/GetType.go | 125 ++++++++++++++++++++++++++++++++++++++++++++++ schema/array.go | 65 ++++++++++++++++++++++++ schema/basic.go | 105 ++++++++++++++++++++++++++++++++++++++ schema/enum.go | 47 +++++++++++++++++ schema/object.go | 78 +++++++++++++++++++++++++++++ schema/type.go | 18 +++++++ toolbox.go | 90 +++++++++++++++++++++++++++++++++ 15 files changed, 693 insertions(+), 20 deletions(-) create mode 100644 error.go create mode 100644 functions.go create mode 100644 schema/GetType.go create mode 100644 schema/array.go create mode 100644 schema/basic.go create mode 100644 schema/enum.go create mode 100644 schema/object.go create mode 100644 schema/type.go create mode 100644 toolbox.go diff --git a/error.go b/error.go new file mode 100644 index 0000000..67a4dec --- /dev/null +++ b/error.go @@ -0,0 +1,21 @@ +package go_llm + +import "fmt" + +// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error. +type Error struct { + error + + Source error + Parameter error +} + +func newError(parent error, err error) Error { + e := fmt.Errorf("%w: %w", parent, err) + return Error{ + error: e, + + Source: parent, + Parameter: err, + } +} diff --git a/function.go b/function.go index 5def8f5..173d4d7 100644 --- a/function.go +++ b/function.go @@ -1,10 +1,88 @@ package go_llm +import ( + "context" + "encoding/json" + "fmt" + "gitea.stevedudenhoeffer.com/steve/go-llm/schema" + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" + "reflect" + "time" +) + type Function struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Strict bool `json:"strict,omitempty"` - Parameters any `json:"parameters"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` + Parameters schema.Type `json:"parameters"` + + Forced bool `json:"forced,omitempty"` + + // Timeout is the maximum time to wait for the function to complete + Timeout time.Duration `json:"-"` + + // fn is the function to call, only set if this is constructed with NewFunction + fn reflect.Value + + paramType reflect.Type + + // definition is a cache of the openaiImpl jsonschema definition + definition *jsonschema.Definition +} + +func (f *Function) Execute(ctx context.Context, input string) (string, error) { + if !f.fn.IsValid() { + return "", fmt.Errorf("function %s is not implemented", f.Name) + } + + // first, we need to parse the input into the struct + p := reflect.New(f.paramType).Elem() + //m := map[string]any{} + err := json.Unmarshal([]byte(input), p.Addr().Interface()) + if err != nil { + return "", fmt.Errorf("failed to unmarshal input: %w", err) + } + + // now we can call the function + exec := func(ctx context.Context) (string, error) { + out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p}) + + if len(out) != 2 { + return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out)) + } + + if out[1].IsNil() { + return out[0].String(), nil + } + + return "", out[1].Interface().(error) + } + + var cancel context.CancelFunc + if f.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, f.Timeout) + defer cancel() + } + + return exec(ctx) +} + +func (f *Function) toOpenAIFunction() *openai.FunctionDefinition { + return &openai.FunctionDefinition{ + Name: f.Name, + Description: f.Description, + Strict: f.Strict, + Parameters: f.Parameters, + } +} +func (f *Function) toOpenAIDefinition() jsonschema.Definition { + if f.definition == nil { + def := f.Parameters.Definition() + f.definition = &def + } + + return *f.definition } type FunctionCall struct { diff --git a/functions.go b/functions.go new file mode 100644 index 0000000..6b93bb6 --- /dev/null +++ b/functions.go @@ -0,0 +1,35 @@ +package go_llm + +import ( + "context" + "gitea.stevedudenhoeffer.com/steve/go-llm/schema" + "reflect" +) + +// Parse takes a function pointer and returns a function object. +// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains +// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to +// those types. +// The struct parameters can have the following tags: +// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for + +func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function { + var o T + + res := Function{ + Name: name, + Description: description, + Parameters: schema.GetType(o), + fn: reflect.ValueOf(fn), + paramType: reflect.TypeOf(o), + } + + if res.fn.Kind() != reflect.Func { + panic("fn must be a function") + } + if res.paramType.Kind() != reflect.Struct { + panic("function parameter must be a struct") + } + + return &res +} diff --git a/go.mod b/go.mod index ff535eb..e4486d7 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,11 @@ module gitea.stevedudenhoeffer.com/steve/go-llm go 1.23.1 require ( + github.com/google/generative-ai-go v0.18.0 github.com/liushuangls/go-anthropic/v2 v2.8.0 github.com/sashabaranov/go-openai v1.31.0 + golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f + google.golang.org/api v0.186.0 ) require ( @@ -19,7 +22,6 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect - github.com/google/generative-ai-go v0.18.0 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect @@ -33,11 +35,10 @@ require ( golang.org/x/crypto v0.24.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect + golang.org/x/sync v0.9.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect - google.golang.org/api v0.186.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/grpc v1.64.1 // indirect diff --git a/go.sum b/go.sum index 7b77f3a..7262523 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,7 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -52,6 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -63,6 +66,7 @@ github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBY github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= github.com/liushuangls/go-anthropic/v2 v2.8.0 h1:0zH2jDNycbrlszxnLrG+Gx8vVT0yJAPWU4s3ZTkWzgI= github.com/liushuangls/go-anthropic/v2 v2.8.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/sashabaranov/go-openai v1.31.0 h1:rGe77x7zUeCjtS2IS7NCY6Tp4bQviXNMhkQM6hz/UC4= @@ -73,6 +77,8 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 h1:A3SayB3rNyt+1S6qpI9mHPkeHTZbD7XILEqWnYZb2l0= @@ -90,6 +96,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -107,8 +115,8 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -158,6 +166,7 @@ google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6h google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/google.go b/google.go index 5535023..778f389 100644 --- a/google.go +++ b/google.go @@ -31,7 +31,7 @@ func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel) } for _, tool := range in.Toolbox { - panic("google toolbox is todo" + tool.Name) + panic("google ToolBox is todo" + tool.Name) /* t := genai.Tool{} diff --git a/llm.go b/llm.go index 87d435b..cb4b99a 100644 --- a/llm.go +++ b/llm.go @@ -27,7 +27,7 @@ type Message struct { type Request struct { Messages []Message - Toolbox []Function + Toolbox *ToolBox Temperature *float32 } @@ -50,6 +50,7 @@ type Response struct { type ChatCompletion interface { ChatComplete(ctx context.Context, req Request) (Response, error) + SplitLongString(ctx context.Context, input string) ([]string, error) } type LLM interface { @@ -57,7 +58,7 @@ type LLM interface { } func OpenAI(key string) LLM { - return openai{key: key} + return openaiImpl{key: key} } func Anthropic(key string) LLM { diff --git a/openai.go b/openai.go index 6f1909d..21ad71b 100644 --- a/openai.go +++ b/openai.go @@ -7,14 +7,14 @@ import ( "strings" ) -type openai struct { +type openaiImpl struct { key string model string } -var _ LLM = openai{} +var _ LLM = openaiImpl{} -func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest { +func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest { res := oai.ChatCompletionRequest{ Model: o.model, } @@ -90,7 +90,7 @@ func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionReques return res } -func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Response { +func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response { res := Response{} for _, choice := range response.Choices { @@ -118,7 +118,7 @@ func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Respo return res } -func (o openai) ChatComplete(ctx context.Context, request Request) (Response, error) { +func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) { cl := oai.NewClient(o.key) req := o.requestToOpenAIRequest(request) @@ -128,14 +128,14 @@ func (o openai) ChatComplete(ctx context.Context, request Request) (Response, er fmt.Println("resp:", fmt.Sprintf("%#v", resp)) if err != nil { - return Response{}, fmt.Errorf("unhandled openai error: %w", err) + return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err) } return o.responseToLLMResponse(resp), nil } -func (o openai) ModelVersion(modelVersion string) (ChatCompletion, error) { - return openai{ +func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { + return openaiImpl{ key: o.key, model: modelVersion, }, nil diff --git a/schema/GetType.go b/schema/GetType.go new file mode 100644 index 0000000..f2f7e16 --- /dev/null +++ b/schema/GetType.go @@ -0,0 +1,125 @@ +package schema + +import ( + "reflect" + "strings" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that +// can be used to generate a json schema and build an object from a parsed json object. +func GetType(a any) Type { + t := reflect.TypeOf(a) + + if t.Kind() != reflect.Struct { + panic("GetType expects a struct") + } + + return getObject(t) +} + +func getFromType(t reflect.Type, b basic) Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + b.required = false + } + + switch t.Kind() { + case reflect.String: + b.DataType = jsonschema.String + return b + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + b.DataType = jsonschema.Integer + return b + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + b.DataType = jsonschema.Integer + return b + + case reflect.Float32, reflect.Float64: + b.DataType = jsonschema.Number + return b + + case reflect.Bool: + b.DataType = jsonschema.Boolean + return b + + case reflect.Struct: + o := getObject(t) + + o.basic.required = b.required + o.basic.index = b.index + o.basic.description = b.description + + return o + + case reflect.Slice: + return getArray(t) + + default: + panic("unhandled default case for " + t.Kind().String() + " in getFromType") + } +} + +func getField(f reflect.StructField, index int) Type { + b := basic{ + index: index, + required: true, + description: "", + } + + t := f.Type + + // if the tag "description" is set, use that as the description + if desc, ok := f.Tag.Lookup("description"); ok { + b.description = desc + } + + // now if the tag "enum" is set, we need to create an enum type + if v, ok := f.Tag.Lookup("enum"); ok { + vals := strings.Split(v, ",") + + for i := 0; i < len(vals); i++ { + vals[i] = strings.TrimSpace(vals[i]) + + if vals[i] == "" { + vals = append(vals[:i], vals[i+1:]...) + } + } + + return enum{ + basic: b, + values: vals, + } + + } + + return getFromType(t, b) +} + +func getObject(t reflect.Type) object { + fields := make(map[string]Type, t.NumField()) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fields[field.Name] = getField(field, i) + } + + return object{ + basic: basic{DataType: jsonschema.Object}, + fields: fields, + } +} + +func getArray(t reflect.Type) array { + res := array{ + basic: basic{ + DataType: jsonschema.Array, + }, + } + + res.items = getFromType(t.Elem(), basic{}) + + return res +} diff --git a/schema/array.go b/schema/array.go new file mode 100644 index 0000000..6773da8 --- /dev/null +++ b/schema/array.go @@ -0,0 +1,65 @@ +package schema + +import ( + "errors" + "reflect" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +type array struct { + basic + + // items is the schema of the items in the array + items Type +} + +func (a array) SchemaType() jsonschema.DataType { + return jsonschema.Array +} + +func (a array) Definition() jsonschema.Definition { + def := a.basic.Definition() + def.Type = jsonschema.Array + i := a.items.Definition() + def.Items = &i + def.AdditionalProperties = false + return def +} + +func (a array) FromAny(val any) (reflect.Value, error) { + v := reflect.ValueOf(val) + + // first realize we may have a pointer to a slice if this type is not required + if !a.required && v.Kind() == reflect.Ptr { + v = v.Elem() + } + + if v.Kind() != reflect.Slice { + return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String()) + } + + // if the slice is nil, we can just return it + if v.IsNil() { + return v, nil + } + + // if the slice is not nil, we need to convert each item + items := make([]reflect.Value, v.Len()) + for i := 0; i < v.Len(); i++ { + item, err := a.items.FromAny(v.Index(i).Interface()) + if err != nil { + return reflect.Value{}, err + } + items[i] = item + } + + return reflect.ValueOf(items), nil +} + +func (a array) SetValue(obj reflect.Value, val reflect.Value) { + if !a.required { + val = val.Addr() + } + obj.Field(a.index).Set(val) +} diff --git a/schema/basic.go b/schema/basic.go new file mode 100644 index 0000000..0175c8f --- /dev/null +++ b/schema/basic.go @@ -0,0 +1,105 @@ +package schema + +import ( + "errors" + "reflect" + "strconv" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// just enforcing that basic implements Type +var _ Type = basic{} + +type basic struct { + jsonschema.DataType + + // index is the position of the parameter in the StructField of the function's parameter struct + index int + + // required is a flag that indicates whether the parameter is required in the function's parameter struct. + // this is inferred by if the parameter is a pointer type or not. + required bool + + // description is a llm-readable description of the parameter passed to openai + description string +} + +func (b basic) SchemaType() jsonschema.DataType { + return b.DataType +} + +func (b basic) Definition() jsonschema.Definition { + return jsonschema.Definition{ + Type: b.DataType, + Description: b.description, + } +} + +func (b basic) Required() bool { + return b.required +} + +func (b basic) Description() string { + return b.description +} + +func (b basic) FromAny(val any) (reflect.Value, error) { + v := reflect.ValueOf(val) + + switch b.DataType { + case jsonschema.String: + var val = v.String() + + return reflect.ValueOf(val), nil + + case jsonschema.Integer: + if v.Kind() == reflect.Float64 { + return v.Convert(reflect.TypeOf(int(0))), nil + } else if v.Kind() != reflect.Int { + return reflect.Value{}, errors.New("expected int, got " + v.Kind().String()) + } else { + return v, nil + } + + case jsonschema.Number: + if v.Kind() == reflect.Float64 { + return v.Convert(reflect.TypeOf(float64(0))), nil + } else if v.Kind() != reflect.Float64 { + return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String()) + } else { + return v, nil + } + + case jsonschema.Boolean: + if v.Kind() == reflect.Bool { + return v, nil + } else if v.Kind() == reflect.String { + b, err := strconv.ParseBool(v.String()) + if err != nil { + return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) + } + return reflect.ValueOf(b), nil + } else { + return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) + } + + default: + return reflect.Value{}, errors.New("unknown type") + } +} + +func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) { + // if this basic type is not required that means it's a pointer type + // so we need to create a new value of the type of the pointer + if !b.required { + vv := reflect.New(obj.Field(b.index).Type().Elem()) + + // and then set the value of the pointer to the new value + vv.Elem().Set(val) + + obj.Field(b.index).Set(vv) + return + } + obj.Field(b.index).Set(val) +} diff --git a/schema/enum.go b/schema/enum.go new file mode 100644 index 0000000..272866f --- /dev/null +++ b/schema/enum.go @@ -0,0 +1,47 @@ +package schema + +import ( + "errors" + "reflect" + + "golang.org/x/exp/slices" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +type enum struct { + basic + + values []string +} + +func (e enum) SchemaType() jsonschema.DataType { + return jsonschema.String +} + +func (e enum) Definition() jsonschema.Definition { + def := e.basic.Definition() + def.Enum = e.values + return def +} + +func (e enum) FromAny(val any) (reflect.Value, error) { + v := reflect.ValueOf(val) + if v.Kind() != reflect.String { + return reflect.Value{}, errors.New("expected string, got " + v.Kind().String()) + } + + s := v.String() + if !slices.Contains(e.values, s) { + return reflect.Value{}, errors.New("value " + s + " not in enum") + } + + return v, nil +} + +func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) { + if !e.required { + val = val.Addr() + } + obj.Field(e.index).Set(val) +} diff --git a/schema/object.go b/schema/object.go new file mode 100644 index 0000000..b0a0c1b --- /dev/null +++ b/schema/object.go @@ -0,0 +1,78 @@ +package schema + +import ( + "errors" + "reflect" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +type object struct { + basic + + ref reflect.Type + + fields map[string]Type +} + +func (o object) SchemaType() jsonschema.DataType { + return jsonschema.Object +} + +func (o object) Definition() jsonschema.Definition { + def := o.basic.Definition() + def.Type = jsonschema.Object + def.Properties = make(map[string]jsonschema.Definition) + for k, v := range o.fields { + def.Properties[k] = v.Definition() + } + + def.AdditionalProperties = false + return def +} + +func (o object) FromAny(val any) (reflect.Value, error) { + // if the value is nil, we can't do anything + if val == nil { + return reflect.Value{}, nil + } + + // now make a new object of the type we're trying to parse + obj := reflect.New(o.ref).Elem() + + // now we need to iterate over the fields and set the values + for k, v := range o.fields { + // get the field by name + field := obj.FieldByName(k) + if !field.IsValid() { + return reflect.Value{}, errors.New("field " + k + " not found") + } + + // get the value from the map + val2, ok := val.(map[string]interface{})[k] + if !ok { + return reflect.Value{}, errors.New("field " + k + " not found in map") + } + + // now we need to convert the value to the correct type + val3, err := v.FromAny(val2) + if err != nil { + return reflect.Value{}, err + } + + // now we need to set the value on the field + v.SetValueOnField(field, val3) + + } + + return obj, nil +} + +func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) { + // if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value + if !o.required { + val = val.Addr() + } + + obj.Field(o.index).Set(val) +} diff --git a/schema/type.go b/schema/type.go new file mode 100644 index 0000000..7f84216 --- /dev/null +++ b/schema/type.go @@ -0,0 +1,18 @@ +package schema + +import ( + "reflect" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +type Type interface { + SchemaType() jsonschema.DataType + Definition() jsonschema.Definition + + Required() bool + Description() string + + FromAny(any) (reflect.Value, error) + SetValueOnField(obj reflect.Value, val reflect.Value) +} diff --git a/toolbox.go b/toolbox.go new file mode 100644 index 0000000..b95fd2d --- /dev/null +++ b/toolbox.go @@ -0,0 +1,90 @@ +package go_llm + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/sashabaranov/go-openai" + "log/slog" +) + +// ToolBox is a collection of tools that OpenAI can use to execute functions. +// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with +// the correct parameters. +type ToolBox struct { + funcs []Function + names map[string]Function +} + +func NewToolBox(fns ...*Function) *ToolBox { + res := ToolBox{ + funcs: []Function{}, + names: map[string]Function{}, + } + + for _, f := range fns { + o := *f + res.names[o.Name] = o + res.funcs = append(res.funcs, o) + } + + return &res +} + +func (t *ToolBox) WithFunction(f Function) *ToolBox { + t2 := *t + t2.names[f.Name] = f + t2.funcs = append(t2.funcs, f) + + return &t2 +} + +// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API. +func (t *ToolBox) toOpenAI() []openai.Tool { + var res []openai.Tool + + for _, f := range t.funcs { + res = append(res, openai.Tool{ + Type: "function", + Function: f.toOpenAIFunction(), + }) + } + + return res +} + +func (t *ToolBox) ToToolChoice() any { + if len(t.funcs) == 0 { + return nil + } + + return "required" +} + +var ( + ErrFunctionNotFound = errors.New("function not found") +) + +func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) { + f, ok := t.names[functionName] + + slog.Info("functionName", functionName) + + if !ok { + return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) + } + + return f.Execute(ctx, params) +} + +func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) { + slog.Info("toolCall", toolCall) + + b, err := json.Marshal(toolCall.FunctionCall.Arguments) + + if err != nil { + return "", fmt.Errorf("failed to marshal arguments: %w", err) + } + return t.ExecuteFunction(ctx, toolCall.ID, string(b)) +}