initial commit of untested function stuff
This commit is contained in:
parent
f603010dee
commit
37939088ed
21
error.go
Normal file
21
error.go
Normal file
@ -0,0 +1,21 @@
|
||||
package go_llm
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error.
|
||||
type Error struct {
|
||||
error
|
||||
|
||||
Source error
|
||||
Parameter error
|
||||
}
|
||||
|
||||
func newError(parent error, err error) Error {
|
||||
e := fmt.Errorf("%w: %w", parent, err)
|
||||
return Error{
|
||||
error: e,
|
||||
|
||||
Source: parent,
|
||||
Parameter: err,
|
||||
}
|
||||
}
|
80
function.go
80
function.go
@ -1,10 +1,88 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
Parameters any `json:"parameters"`
|
||||
Parameters schema.Type `json:"parameters"`
|
||||
|
||||
Forced bool `json:"forced,omitempty"`
|
||||
|
||||
// Timeout is the maximum time to wait for the function to complete
|
||||
Timeout time.Duration `json:"-"`
|
||||
|
||||
// fn is the function to call, only set if this is constructed with NewFunction
|
||||
fn reflect.Value
|
||||
|
||||
paramType reflect.Type
|
||||
|
||||
// definition is a cache of the openaiImpl jsonschema definition
|
||||
definition *jsonschema.Definition
|
||||
}
|
||||
|
||||
func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
||||
if !f.fn.IsValid() {
|
||||
return "", fmt.Errorf("function %s is not implemented", f.Name)
|
||||
}
|
||||
|
||||
// first, we need to parse the input into the struct
|
||||
p := reflect.New(f.paramType).Elem()
|
||||
//m := map[string]any{}
|
||||
err := json.Unmarshal([]byte(input), p.Addr().Interface())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal input: %w", err)
|
||||
}
|
||||
|
||||
// now we can call the function
|
||||
exec := func(ctx context.Context) (string, error) {
|
||||
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p})
|
||||
|
||||
if len(out) != 2 {
|
||||
return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out))
|
||||
}
|
||||
|
||||
if out[1].IsNil() {
|
||||
return out[0].String(), nil
|
||||
}
|
||||
|
||||
return "", out[1].Interface().(error)
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
if f.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, f.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
return exec(ctx)
|
||||
}
|
||||
|
||||
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
|
||||
return &openai.FunctionDefinition{
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
Strict: f.Strict,
|
||||
Parameters: f.Parameters,
|
||||
}
|
||||
}
|
||||
func (f *Function) toOpenAIDefinition() jsonschema.Definition {
|
||||
if f.definition == nil {
|
||||
def := f.Parameters.Definition()
|
||||
f.definition = &def
|
||||
}
|
||||
|
||||
return *f.definition
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
|
35
functions.go
Normal file
35
functions.go
Normal file
@ -0,0 +1,35 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Parse takes a function pointer and returns a function object.
|
||||
// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains
|
||||
// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to
|
||||
// those types.
|
||||
// The struct parameters can have the following tags:
|
||||
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
|
||||
|
||||
func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
|
||||
var o T
|
||||
|
||||
res := Function{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Parameters: schema.GetType(o),
|
||||
fn: reflect.ValueOf(fn),
|
||||
paramType: reflect.TypeOf(o),
|
||||
}
|
||||
|
||||
if res.fn.Kind() != reflect.Func {
|
||||
panic("fn must be a function")
|
||||
}
|
||||
if res.paramType.Kind() != reflect.Struct {
|
||||
panic("function parameter must be a struct")
|
||||
}
|
||||
|
||||
return &res
|
||||
}
|
7
go.mod
7
go.mod
@ -3,8 +3,11 @@ module gitea.stevedudenhoeffer.com/steve/go-llm
|
||||
go 1.23.1
|
||||
|
||||
require (
|
||||
github.com/google/generative-ai-go v0.18.0
|
||||
github.com/liushuangls/go-anthropic/v2 v2.8.0
|
||||
github.com/sashabaranov/go-openai v1.31.0
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f
|
||||
google.golang.org/api v0.186.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -19,7 +22,6 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/generative-ai-go v0.18.0 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
@ -33,11 +35,10 @@ require (
|
||||
golang.org/x/crypto v0.24.0 // indirect
|
||||
golang.org/x/net v0.26.0 // indirect
|
||||
golang.org/x/oauth2 v0.21.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sync v0.9.0 // indirect
|
||||
golang.org/x/sys v0.21.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
google.golang.org/api v0.186.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 // indirect
|
||||
google.golang.org/grpc v1.64.1 // indirect
|
||||
|
13
go.sum
13
go.sum
@ -16,6 +16,7 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@ -52,6 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
|
||||
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@ -63,6 +66,7 @@ github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBY
|
||||
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
|
||||
github.com/liushuangls/go-anthropic/v2 v2.8.0 h1:0zH2jDNycbrlszxnLrG+Gx8vVT0yJAPWU4s3ZTkWzgI=
|
||||
github.com/liushuangls/go-anthropic/v2 v2.8.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/sashabaranov/go-openai v1.31.0 h1:rGe77x7zUeCjtS2IS7NCY6Tp4bQviXNMhkQM6hz/UC4=
|
||||
@ -73,6 +77,8 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 h1:A3SayB3rNyt+1S6qpI9mHPkeHTZbD7XILEqWnYZb2l0=
|
||||
@ -90,6 +96,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
@ -107,8 +115,8 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
|
||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -158,6 +166,7 @@ google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6h
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
@ -31,7 +31,7 @@ func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel)
|
||||
}
|
||||
|
||||
for _, tool := range in.Toolbox {
|
||||
panic("google toolbox is todo" + tool.Name)
|
||||
panic("google ToolBox is todo" + tool.Name)
|
||||
|
||||
/*
|
||||
t := genai.Tool{}
|
||||
|
5
llm.go
5
llm.go
@ -27,7 +27,7 @@ type Message struct {
|
||||
|
||||
type Request struct {
|
||||
Messages []Message
|
||||
Toolbox []Function
|
||||
Toolbox *ToolBox
|
||||
Temperature *float32
|
||||
}
|
||||
|
||||
@ -50,6 +50,7 @@ type Response struct {
|
||||
|
||||
type ChatCompletion interface {
|
||||
ChatComplete(ctx context.Context, req Request) (Response, error)
|
||||
SplitLongString(ctx context.Context, input string) ([]string, error)
|
||||
}
|
||||
|
||||
type LLM interface {
|
||||
@ -57,7 +58,7 @@ type LLM interface {
|
||||
}
|
||||
|
||||
func OpenAI(key string) LLM {
|
||||
return openai{key: key}
|
||||
return openaiImpl{key: key}
|
||||
}
|
||||
|
||||
func Anthropic(key string) LLM {
|
||||
|
16
openai.go
16
openai.go
@ -7,14 +7,14 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type openai struct {
|
||||
type openaiImpl struct {
|
||||
key string
|
||||
model string
|
||||
}
|
||||
|
||||
var _ LLM = openai{}
|
||||
var _ LLM = openaiImpl{}
|
||||
|
||||
func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||
res := oai.ChatCompletionRequest{
|
||||
Model: o.model,
|
||||
}
|
||||
@ -90,7 +90,7 @@ func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionReques
|
||||
return res
|
||||
}
|
||||
|
||||
func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
||||
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
||||
res := Response{}
|
||||
|
||||
for _, choice := range response.Choices {
|
||||
@ -118,7 +118,7 @@ func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Respo
|
||||
return res
|
||||
}
|
||||
|
||||
func (o openai) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||
cl := oai.NewClient(o.key)
|
||||
|
||||
req := o.requestToOpenAIRequest(request)
|
||||
@ -128,14 +128,14 @@ func (o openai) ChatComplete(ctx context.Context, request Request) (Response, er
|
||||
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
||||
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("unhandled openai error: %w", err)
|
||||
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
||||
}
|
||||
|
||||
return o.responseToLLMResponse(resp), nil
|
||||
}
|
||||
|
||||
func (o openai) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||
return openai{
|
||||
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||
return openaiImpl{
|
||||
key: o.key,
|
||||
model: modelVersion,
|
||||
}, nil
|
||||
|
125
schema/GetType.go
Normal file
125
schema/GetType.go
Normal file
@ -0,0 +1,125 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
|
||||
// can be used to generate a json schema and build an object from a parsed json object.
|
||||
func GetType(a any) Type {
|
||||
t := reflect.TypeOf(a)
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
panic("GetType expects a struct")
|
||||
}
|
||||
|
||||
return getObject(t)
|
||||
}
|
||||
|
||||
func getFromType(t reflect.Type, b basic) Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
b.required = false
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
b.DataType = jsonschema.String
|
||||
return b
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
b.DataType = jsonschema.Integer
|
||||
return b
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
b.DataType = jsonschema.Integer
|
||||
return b
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
b.DataType = jsonschema.Number
|
||||
return b
|
||||
|
||||
case reflect.Bool:
|
||||
b.DataType = jsonschema.Boolean
|
||||
return b
|
||||
|
||||
case reflect.Struct:
|
||||
o := getObject(t)
|
||||
|
||||
o.basic.required = b.required
|
||||
o.basic.index = b.index
|
||||
o.basic.description = b.description
|
||||
|
||||
return o
|
||||
|
||||
case reflect.Slice:
|
||||
return getArray(t)
|
||||
|
||||
default:
|
||||
panic("unhandled default case for " + t.Kind().String() + " in getFromType")
|
||||
}
|
||||
}
|
||||
|
||||
func getField(f reflect.StructField, index int) Type {
|
||||
b := basic{
|
||||
index: index,
|
||||
required: true,
|
||||
description: "",
|
||||
}
|
||||
|
||||
t := f.Type
|
||||
|
||||
// if the tag "description" is set, use that as the description
|
||||
if desc, ok := f.Tag.Lookup("description"); ok {
|
||||
b.description = desc
|
||||
}
|
||||
|
||||
// now if the tag "enum" is set, we need to create an enum type
|
||||
if v, ok := f.Tag.Lookup("enum"); ok {
|
||||
vals := strings.Split(v, ",")
|
||||
|
||||
for i := 0; i < len(vals); i++ {
|
||||
vals[i] = strings.TrimSpace(vals[i])
|
||||
|
||||
if vals[i] == "" {
|
||||
vals = append(vals[:i], vals[i+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
return enum{
|
||||
basic: b,
|
||||
values: vals,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return getFromType(t, b)
|
||||
}
|
||||
|
||||
func getObject(t reflect.Type) object {
|
||||
fields := make(map[string]Type, t.NumField())
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fields[field.Name] = getField(field, i)
|
||||
}
|
||||
|
||||
return object{
|
||||
basic: basic{DataType: jsonschema.Object},
|
||||
fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
func getArray(t reflect.Type) array {
|
||||
res := array{
|
||||
basic: basic{
|
||||
DataType: jsonschema.Array,
|
||||
},
|
||||
}
|
||||
|
||||
res.items = getFromType(t.Elem(), basic{})
|
||||
|
||||
return res
|
||||
}
|
65
schema/array.go
Normal file
65
schema/array.go
Normal file
@ -0,0 +1,65 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type array struct {
|
||||
basic
|
||||
|
||||
// items is the schema of the items in the array
|
||||
items Type
|
||||
}
|
||||
|
||||
func (a array) SchemaType() jsonschema.DataType {
|
||||
return jsonschema.Array
|
||||
}
|
||||
|
||||
func (a array) Definition() jsonschema.Definition {
|
||||
def := a.basic.Definition()
|
||||
def.Type = jsonschema.Array
|
||||
i := a.items.Definition()
|
||||
def.Items = &i
|
||||
def.AdditionalProperties = false
|
||||
return def
|
||||
}
|
||||
|
||||
func (a array) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
|
||||
// first realize we may have a pointer to a slice if this type is not required
|
||||
if !a.required && v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Slice {
|
||||
return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String())
|
||||
}
|
||||
|
||||
// if the slice is nil, we can just return it
|
||||
if v.IsNil() {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// if the slice is not nil, we need to convert each item
|
||||
items := make([]reflect.Value, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
item, err := a.items.FromAny(v.Index(i).Interface())
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
return reflect.ValueOf(items), nil
|
||||
}
|
||||
|
||||
func (a array) SetValue(obj reflect.Value, val reflect.Value) {
|
||||
if !a.required {
|
||||
val = val.Addr()
|
||||
}
|
||||
obj.Field(a.index).Set(val)
|
||||
}
|
105
schema/basic.go
Normal file
105
schema/basic.go
Normal file
@ -0,0 +1,105 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
// just enforcing that basic implements Type
|
||||
var _ Type = basic{}
|
||||
|
||||
type basic struct {
|
||||
jsonschema.DataType
|
||||
|
||||
// index is the position of the parameter in the StructField of the function's parameter struct
|
||||
index int
|
||||
|
||||
// required is a flag that indicates whether the parameter is required in the function's parameter struct.
|
||||
// this is inferred by if the parameter is a pointer type or not.
|
||||
required bool
|
||||
|
||||
// description is a llm-readable description of the parameter passed to openai
|
||||
description string
|
||||
}
|
||||
|
||||
func (b basic) SchemaType() jsonschema.DataType {
|
||||
return b.DataType
|
||||
}
|
||||
|
||||
func (b basic) Definition() jsonschema.Definition {
|
||||
return jsonschema.Definition{
|
||||
Type: b.DataType,
|
||||
Description: b.description,
|
||||
}
|
||||
}
|
||||
|
||||
func (b basic) Required() bool {
|
||||
return b.required
|
||||
}
|
||||
|
||||
func (b basic) Description() string {
|
||||
return b.description
|
||||
}
|
||||
|
||||
func (b basic) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
|
||||
switch b.DataType {
|
||||
case jsonschema.String:
|
||||
var val = v.String()
|
||||
|
||||
return reflect.ValueOf(val), nil
|
||||
|
||||
case jsonschema.Integer:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(int(0))), nil
|
||||
} else if v.Kind() != reflect.Int {
|
||||
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
|
||||
} else {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case jsonschema.Number:
|
||||
if v.Kind() == reflect.Float64 {
|
||||
return v.Convert(reflect.TypeOf(float64(0))), nil
|
||||
} else if v.Kind() != reflect.Float64 {
|
||||
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
|
||||
} else {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case jsonschema.Boolean:
|
||||
if v.Kind() == reflect.Bool {
|
||||
return v, nil
|
||||
} else if v.Kind() == reflect.String {
|
||||
b, err := strconv.ParseBool(v.String())
|
||||
if err != nil {
|
||||
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
||||
}
|
||||
return reflect.ValueOf(b), nil
|
||||
} else {
|
||||
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
||||
}
|
||||
|
||||
default:
|
||||
return reflect.Value{}, errors.New("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// if this basic type is not required that means it's a pointer type
|
||||
// so we need to create a new value of the type of the pointer
|
||||
if !b.required {
|
||||
vv := reflect.New(obj.Field(b.index).Type().Elem())
|
||||
|
||||
// and then set the value of the pointer to the new value
|
||||
vv.Elem().Set(val)
|
||||
|
||||
obj.Field(b.index).Set(vv)
|
||||
return
|
||||
}
|
||||
obj.Field(b.index).Set(val)
|
||||
}
|
47
schema/enum.go
Normal file
47
schema/enum.go
Normal file
@ -0,0 +1,47 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type enum struct {
|
||||
basic
|
||||
|
||||
values []string
|
||||
}
|
||||
|
||||
func (e enum) SchemaType() jsonschema.DataType {
|
||||
return jsonschema.String
|
||||
}
|
||||
|
||||
func (e enum) Definition() jsonschema.Definition {
|
||||
def := e.basic.Definition()
|
||||
def.Enum = e.values
|
||||
return def
|
||||
}
|
||||
|
||||
func (e enum) FromAny(val any) (reflect.Value, error) {
|
||||
v := reflect.ValueOf(val)
|
||||
if v.Kind() != reflect.String {
|
||||
return reflect.Value{}, errors.New("expected string, got " + v.Kind().String())
|
||||
}
|
||||
|
||||
s := v.String()
|
||||
if !slices.Contains(e.values, s) {
|
||||
return reflect.Value{}, errors.New("value " + s + " not in enum")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
if !e.required {
|
||||
val = val.Addr()
|
||||
}
|
||||
obj.Field(e.index).Set(val)
|
||||
}
|
78
schema/object.go
Normal file
78
schema/object.go
Normal file
@ -0,0 +1,78 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type object struct {
|
||||
basic
|
||||
|
||||
ref reflect.Type
|
||||
|
||||
fields map[string]Type
|
||||
}
|
||||
|
||||
func (o object) SchemaType() jsonschema.DataType {
|
||||
return jsonschema.Object
|
||||
}
|
||||
|
||||
func (o object) Definition() jsonschema.Definition {
|
||||
def := o.basic.Definition()
|
||||
def.Type = jsonschema.Object
|
||||
def.Properties = make(map[string]jsonschema.Definition)
|
||||
for k, v := range o.fields {
|
||||
def.Properties[k] = v.Definition()
|
||||
}
|
||||
|
||||
def.AdditionalProperties = false
|
||||
return def
|
||||
}
|
||||
|
||||
func (o object) FromAny(val any) (reflect.Value, error) {
|
||||
// if the value is nil, we can't do anything
|
||||
if val == nil {
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
|
||||
// now make a new object of the type we're trying to parse
|
||||
obj := reflect.New(o.ref).Elem()
|
||||
|
||||
// now we need to iterate over the fields and set the values
|
||||
for k, v := range o.fields {
|
||||
// get the field by name
|
||||
field := obj.FieldByName(k)
|
||||
if !field.IsValid() {
|
||||
return reflect.Value{}, errors.New("field " + k + " not found")
|
||||
}
|
||||
|
||||
// get the value from the map
|
||||
val2, ok := val.(map[string]interface{})[k]
|
||||
if !ok {
|
||||
return reflect.Value{}, errors.New("field " + k + " not found in map")
|
||||
}
|
||||
|
||||
// now we need to convert the value to the correct type
|
||||
val3, err := v.FromAny(val2)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// now we need to set the value on the field
|
||||
v.SetValueOnField(field, val3)
|
||||
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
|
||||
if !o.required {
|
||||
val = val.Addr()
|
||||
}
|
||||
|
||||
obj.Field(o.index).Set(val)
|
||||
}
|
18
schema/type.go
Normal file
18
schema/type.go
Normal file
@ -0,0 +1,18 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type Type interface {
|
||||
SchemaType() jsonschema.DataType
|
||||
Definition() jsonschema.Definition
|
||||
|
||||
Required() bool
|
||||
Description() string
|
||||
|
||||
FromAny(any) (reflect.Value, error)
|
||||
SetValueOnField(obj reflect.Value, val reflect.Value)
|
||||
}
|
90
toolbox.go
Normal file
90
toolbox.go
Normal file
@ -0,0 +1,90 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
||||
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
||||
// the correct parameters.
|
||||
type ToolBox struct {
|
||||
funcs []Function
|
||||
names map[string]Function
|
||||
}
|
||||
|
||||
func NewToolBox(fns ...*Function) *ToolBox {
|
||||
res := ToolBox{
|
||||
funcs: []Function{},
|
||||
names: map[string]Function{},
|
||||
}
|
||||
|
||||
for _, f := range fns {
|
||||
o := *f
|
||||
res.names[o.Name] = o
|
||||
res.funcs = append(res.funcs, o)
|
||||
}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
||||
t2 := *t
|
||||
t2.names[f.Name] = f
|
||||
t2.funcs = append(t2.funcs, f)
|
||||
|
||||
return &t2
|
||||
}
|
||||
|
||||
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
||||
func (t *ToolBox) toOpenAI() []openai.Tool {
|
||||
var res []openai.Tool
|
||||
|
||||
for _, f := range t.funcs {
|
||||
res = append(res, openai.Tool{
|
||||
Type: "function",
|
||||
Function: f.toOpenAIFunction(),
|
||||
})
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t *ToolBox) ToToolChoice() any {
|
||||
if len(t.funcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return "required"
|
||||
}
|
||||
|
||||
var (
|
||||
ErrFunctionNotFound = errors.New("function not found")
|
||||
)
|
||||
|
||||
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
|
||||
f, ok := t.names[functionName]
|
||||
|
||||
slog.Info("functionName", functionName)
|
||||
|
||||
if !ok {
|
||||
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
||||
}
|
||||
|
||||
return f.Execute(ctx, params)
|
||||
}
|
||||
|
||||
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
|
||||
slog.Info("toolCall", toolCall)
|
||||
|
||||
b, err := json.Marshal(toolCall.FunctionCall.Arguments)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal arguments: %w", err)
|
||||
}
|
||||
return t.ExecuteFunction(ctx, toolCall.ID, string(b))
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user