Compare commits

...

7 Commits

Author SHA1 Message Date
0d70ec46de Fix role setting for assistant-sent images in Anthropic API
Anthropic API does not support assistants sending images directly, so the role is adjusted to "user" for such messages. This ensures compatibility and prevents errors when processing image messages.
2025-01-09 01:18:11 -05:00
6e2b5a33c0 fix anthropic 2024-12-29 19:45:28 -05:00
dfb768d966 make toolbox optional 2024-12-28 20:39:57 -05:00
46a526fd5a Merge branch 'main' of ssh://nuc.dudenhoeffer.casa:222/steve/go-llm 2024-12-28 19:49:26 -05:00
0993a8e865 Fix unmarshalling issues and adjust logging for debugging
Modify `FunctionCall` struct to handle arguments as strings. Add debugging logs to facilitate error tracing and improve JSON unmarshalling in various functions.
2024-11-11 00:23:01 -05:00
cd4ad59a38 sync of changes 2024-11-09 19:50:14 -05:00
37939088ed initial commit of untested function stuff 2024-11-08 20:53:12 -05:00
16 changed files with 740 additions and 46 deletions

View File

@@ -2,9 +2,11 @@ package go_llm
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"net/http"
anth "github.com/liushuangls/go-anthropic/v2"
@@ -24,6 +26,13 @@ func (a anthropic) ModelVersion(modelVersion string) (ChatCompletion, error) {
return a, nil
}
func deferClose(c io.Closer) {
err := c.Close()
if err != nil {
slog.Error("error closing", "error", err)
}
}
func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
res := anth.MessagesRequest{
Model: anth.Model(a.model),
@@ -62,6 +71,11 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
}
for _, img := range msg.Images {
// anthropic doesn't allow the assistant to send images, so we need to say it's from the user
if m.Role == anth.RoleAssistant {
m.Role = anth.RoleUser
}
if img.Base64 != "" {
m.Content = append(m.Content, anth.NewImageMessageContent(
anth.NewMessageContentSource(
@@ -84,7 +98,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
continue
}
defer resp.Body.Close()
defer deferClose(resp.Body)
img.ContentType = resp.Header.Get("Content-Type")
@@ -120,12 +134,14 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
}
}
for _, tool := range req.Toolbox {
res.Tools = append(res.Tools, anth.ToolDefinition{
Name: tool.Name,
Description: tool.Description,
InputSchema: tool.Parameters,
})
if req.Toolbox != nil {
for _, tool := range req.Toolbox.funcs {
res.Tools = append(res.Tools, anth.ToolDefinition{
Name: tool.Name,
Description: tool.Description,
InputSchema: tool.Parameters,
})
}
}
res.Messages = msgs
@@ -153,13 +169,18 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
case anth.MessagesContentTypeToolUse:
if msg.MessageContentToolUse != nil {
choice.Calls = append(choice.Calls, ToolCall{
ID: msg.MessageContentToolUse.ID,
FunctionCall: FunctionCall{
Name: msg.MessageContentToolUse.Name,
Arguments: msg.MessageContentToolUse.Input,
},
})
b, e := json.Marshal(msg.MessageContentToolUse.Input)
if e != nil {
log.Println("failed to marshal input", e)
} else {
choice.Calls = append(choice.Calls, ToolCall{
ID: msg.MessageContentToolUse.ID,
FunctionCall: FunctionCall{
Name: msg.MessageContentToolUse.Name,
Arguments: string(b),
},
})
}
}
}

21
error.go Normal file
View File

@@ -0,0 +1,21 @@
package go_llm
import "fmt"
// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error.
type Error struct {
error
Source error
Parameter error
}
func newError(parent error, err error) Error {
e := fmt.Errorf("%w: %w", parent, err)
return Error{
error: e,
Source: parent,
Parameter: err,
}
}

View File

@@ -1,13 +1,92 @@
package go_llm
import (
"context"
"encoding/json"
"fmt"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"reflect"
"time"
)
type Function struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
Parameters any `json:"parameters"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
Parameters schema.Type `json:"parameters"`
Forced bool `json:"forced,omitempty"`
// Timeout is the maximum time to wait for the function to complete
Timeout time.Duration `json:"-"`
// fn is the function to call, only set if this is constructed with NewFunction
fn reflect.Value
paramType reflect.Type
// definition is a cache of the openaiImpl jsonschema definition
definition *jsonschema.Definition
}
func (f *Function) Execute(ctx context.Context, input string) (string, error) {
if !f.fn.IsValid() {
return "", fmt.Errorf("function %s is not implemented", f.Name)
}
// first, we need to parse the input into the struct
p := reflect.New(f.paramType)
fmt.Println("Function.Execute", f.Name, "input:", input)
//m := map[string]any{}
err := json.Unmarshal([]byte(input), p.Interface())
if err != nil {
return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
}
// now we can call the function
exec := func(ctx context.Context) (string, error) {
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if len(out) != 2 {
return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out))
}
if out[1].IsNil() {
return out[0].String(), nil
}
return "", out[1].Interface().(error)
}
var cancel context.CancelFunc
if f.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, f.Timeout)
defer cancel()
}
return exec(ctx)
}
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
return &openai.FunctionDefinition{
Name: f.Name,
Description: f.Description,
Strict: f.Strict,
Parameters: f.Parameters,
}
}
func (f *Function) toOpenAIDefinition() jsonschema.Definition {
if f.definition == nil {
def := f.Parameters.Definition()
f.definition = &def
}
return *f.definition
}
type FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments any `json:"arguments,omitempty"`
Arguments string `json:"arguments,omitempty"`
}

35
functions.go Normal file
View File

@@ -0,0 +1,35 @@
package go_llm
import (
"context"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
"reflect"
)
// Parse takes a function pointer and returns a function object.
// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains
// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to
// those types.
// The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
var o T
res := Function{
Name: name,
Description: description,
Parameters: schema.GetType(o),
fn: reflect.ValueOf(fn),
paramType: reflect.TypeOf(o),
}
if res.fn.Kind() != reflect.Func {
panic("fn must be a function")
}
if res.paramType.Kind() != reflect.Struct {
panic("function parameter must be a struct")
}
return &res
}

3
go.mod
View File

@@ -5,7 +5,8 @@ go 1.23.1
require (
github.com/google/generative-ai-go v0.19.0
github.com/liushuangls/go-anthropic/v2 v2.13.0
github.com/sashabaranov/go-openai v1.36.0
github.com/sashabaranov/go-openai v1.36.1
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
google.golang.org/api v0.214.0
)

4
go.sum
View File

@@ -39,6 +39,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI=
github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
@@ -59,6 +61,8 @@ go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qq
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=

View File

@@ -2,6 +2,7 @@ package go_llm
import (
"context"
"encoding/json"
"fmt"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
@@ -30,8 +31,8 @@ func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel)
res = append(res, genai.Text(c.Text))
}
for _, tool := range in.Toolbox {
panic("google toolbox is todo" + tool.Name)
for _, tool := range in.Toolbox.funcs {
panic("google ToolBox is todo" + tool.Name)
/*
t := genai.Tool{}
@@ -63,11 +64,17 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
choice := ResponseChoice{}
choice.Content = v.Name
b, e := json.Marshal(v.Args)
if e != nil {
return Response{}, fmt.Errorf("error marshalling args: %w", e)
}
call := ToolCall{
ID: v.Name,
FunctionCall: FunctionCall{
Name: v.Name,
Arguments: v.Args,
Arguments: string(b),
},
}

4
llm.go
View File

@@ -27,7 +27,7 @@ type Message struct {
type Request struct {
Messages []Message
Toolbox []Function
Toolbox *ToolBox
Temperature *float32
}
@@ -57,7 +57,7 @@ type LLM interface {
}
func OpenAI(key string) LLM {
return openai{key: key}
return openaiImpl{key: key}
}
func Anthropic(key string) LLM {

View File

@@ -3,18 +3,19 @@ package go_llm
import (
"context"
"fmt"
oai "github.com/sashabaranov/go-openai"
"strings"
oai "github.com/sashabaranov/go-openai"
)
type openai struct {
type openaiImpl struct {
key string
model string
}
var _ LLM = openai{}
var _ LLM = openaiImpl{}
func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
res := oai.ChatCompletionRequest{
Model: o.model,
}
@@ -57,16 +58,20 @@ func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionReques
res.Messages = append(res.Messages, m)
}
for _, tool := range request.Toolbox {
res.Tools = append(res.Tools, oai.Tool{
Type: "function",
Function: &oai.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Strict: tool.Strict,
Parameters: tool.Parameters,
},
})
if request.Toolbox != nil {
for _, tool := range request.Toolbox.funcs {
res.Tools = append(res.Tools, oai.Tool{
Type: "function",
Function: &oai.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Strict: tool.Strict,
Parameters: tool.Parameters.Definition(),
},
})
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
}
}
if request.Temperature != nil {
@@ -90,12 +95,13 @@ func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionReques
return res
}
func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
res := Response{}
for _, choice := range response.Choices {
var tools []ToolCall
var toolCalls []ToolCall
for _, call := range choice.Message.ToolCalls {
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
toolCall := ToolCall{
ID: call.ID,
FunctionCall: FunctionCall{
@@ -104,7 +110,9 @@ func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Respo
},
}
tools = append(tools, toolCall)
fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
toolCalls = append(toolCalls, toolCall)
}
res.Choices = append(res.Choices, ResponseChoice{
@@ -112,13 +120,14 @@ func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Respo
Role: Role(choice.Message.Role),
Name: choice.Message.Name,
Refusal: choice.Message.Refusal,
Calls: toolCalls,
})
}
return res
}
func (o openai) ChatComplete(ctx context.Context, request Request) (Response, error) {
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
cl := oai.NewClient(o.key)
req := o.requestToOpenAIRequest(request)
@@ -128,14 +137,14 @@ func (o openai) ChatComplete(ctx context.Context, request Request) (Response, er
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
if err != nil {
return Response{}, fmt.Errorf("unhandled openai error: %w", err)
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
}
return o.responseToLLMResponse(resp), nil
}
func (o openai) ModelVersion(modelVersion string) (ChatCompletion, error) {
return openai{
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
return openaiImpl{
key: o.key,
model: modelVersion,
}, nil

125
schema/GetType.go Normal file
View File

@@ -0,0 +1,125 @@
package schema
import (
"reflect"
"strings"
"github.com/sashabaranov/go-openai/jsonschema"
)
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
// can be used to generate a json schema and build an object from a parsed json object.
func GetType(a any) Type {
t := reflect.TypeOf(a)
if t.Kind() != reflect.Struct {
panic("GetType expects a struct")
}
return getObject(t)
}
func getFromType(t reflect.Type, b basic) Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
b.required = false
}
switch t.Kind() {
case reflect.String:
b.DataType = jsonschema.String
return b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b.DataType = jsonschema.Integer
return b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b.DataType = jsonschema.Integer
return b
case reflect.Float32, reflect.Float64:
b.DataType = jsonschema.Number
return b
case reflect.Bool:
b.DataType = jsonschema.Boolean
return b
case reflect.Struct:
o := getObject(t)
o.basic.required = b.required
o.basic.index = b.index
o.basic.description = b.description
return o
case reflect.Slice:
return getArray(t)
default:
panic("unhandled default case for " + t.Kind().String() + " in getFromType")
}
}
func getField(f reflect.StructField, index int) Type {
b := basic{
index: index,
required: true,
description: "",
}
t := f.Type
// if the tag "description" is set, use that as the description
if desc, ok := f.Tag.Lookup("description"); ok {
b.description = desc
}
// now if the tag "enum" is set, we need to create an enum type
if v, ok := f.Tag.Lookup("enum"); ok {
vals := strings.Split(v, ",")
for i := 0; i < len(vals); i++ {
vals[i] = strings.TrimSpace(vals[i])
if vals[i] == "" {
vals = append(vals[:i], vals[i+1:]...)
}
}
return enum{
basic: b,
values: vals,
}
}
return getFromType(t, b)
}
func getObject(t reflect.Type) object {
fields := make(map[string]Type, t.NumField())
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fields[field.Name] = getField(field, i)
}
return object{
basic: basic{DataType: jsonschema.Object},
fields: fields,
}
}
func getArray(t reflect.Type) array {
res := array{
basic: basic{
DataType: jsonschema.Array,
},
}
res.items = getFromType(t.Elem(), basic{})
return res
}

65
schema/array.go Normal file
View File

@@ -0,0 +1,65 @@
package schema
import (
"errors"
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
)
type array struct {
basic
// items is the schema of the items in the array
items Type
}
func (a array) SchemaType() jsonschema.DataType {
return jsonschema.Array
}
func (a array) Definition() jsonschema.Definition {
def := a.basic.Definition()
def.Type = jsonschema.Array
i := a.items.Definition()
def.Items = &i
def.AdditionalProperties = false
return def
}
func (a array) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
// first realize we may have a pointer to a slice if this type is not required
if !a.required && v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Slice {
return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String())
}
// if the slice is nil, we can just return it
if v.IsNil() {
return v, nil
}
// if the slice is not nil, we need to convert each item
items := make([]reflect.Value, v.Len())
for i := 0; i < v.Len(); i++ {
item, err := a.items.FromAny(v.Index(i).Interface())
if err != nil {
return reflect.Value{}, err
}
items[i] = item
}
return reflect.ValueOf(items), nil
}
func (a array) SetValue(obj reflect.Value, val reflect.Value) {
if !a.required {
val = val.Addr()
}
obj.Field(a.index).Set(val)
}

105
schema/basic.go Normal file
View File

@@ -0,0 +1,105 @@
package schema
import (
"errors"
"reflect"
"strconv"
"github.com/sashabaranov/go-openai/jsonschema"
)
// just enforcing that basic implements Type
var _ Type = basic{}
type basic struct {
jsonschema.DataType
// index is the position of the parameter in the StructField of the function's parameter struct
index int
// required is a flag that indicates whether the parameter is required in the function's parameter struct.
// this is inferred by if the parameter is a pointer type or not.
required bool
// description is a llm-readable description of the parameter passed to openai
description string
}
func (b basic) SchemaType() jsonschema.DataType {
return b.DataType
}
func (b basic) Definition() jsonschema.Definition {
return jsonschema.Definition{
Type: b.DataType,
Description: b.description,
}
}
func (b basic) Required() bool {
return b.required
}
func (b basic) Description() string {
return b.description
}
func (b basic) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
switch b.DataType {
case jsonschema.String:
var val = v.String()
return reflect.ValueOf(val), nil
case jsonschema.Integer:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(int(0))), nil
} else if v.Kind() != reflect.Int {
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
} else {
return v, nil
}
case jsonschema.Number:
if v.Kind() == reflect.Float64 {
return v.Convert(reflect.TypeOf(float64(0))), nil
} else if v.Kind() != reflect.Float64 {
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
} else {
return v, nil
}
case jsonschema.Boolean:
if v.Kind() == reflect.Bool {
return v, nil
} else if v.Kind() == reflect.String {
b, err := strconv.ParseBool(v.String())
if err != nil {
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
}
return reflect.ValueOf(b), nil
} else {
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
}
default:
return reflect.Value{}, errors.New("unknown type")
}
}
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type
// so we need to create a new value of the type of the pointer
if !b.required {
vv := reflect.New(obj.Field(b.index).Type().Elem())
// and then set the value of the pointer to the new value
vv.Elem().Set(val)
obj.Field(b.index).Set(vv)
return
}
obj.Field(b.index).Set(val)
}

47
schema/enum.go Normal file
View File

@@ -0,0 +1,47 @@
package schema
import (
"errors"
"reflect"
"golang.org/x/exp/slices"
"github.com/sashabaranov/go-openai/jsonschema"
)
type enum struct {
basic
values []string
}
func (e enum) SchemaType() jsonschema.DataType {
return jsonschema.String
}
func (e enum) Definition() jsonschema.Definition {
def := e.basic.Definition()
def.Enum = e.values
return def
}
func (e enum) FromAny(val any) (reflect.Value, error) {
v := reflect.ValueOf(val)
if v.Kind() != reflect.String {
return reflect.Value{}, errors.New("expected string, got " + v.Kind().String())
}
s := v.String()
if !slices.Contains(e.values, s) {
return reflect.Value{}, errors.New("value " + s + " not in enum")
}
return v, nil
}
func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) {
if !e.required {
val = val.Addr()
}
obj.Field(e.index).Set(val)
}

78
schema/object.go Normal file
View File

@@ -0,0 +1,78 @@
package schema
import (
"errors"
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
)
type object struct {
basic
ref reflect.Type
fields map[string]Type
}
func (o object) SchemaType() jsonschema.DataType {
return jsonschema.Object
}
func (o object) Definition() jsonschema.Definition {
def := o.basic.Definition()
def.Type = jsonschema.Object
def.Properties = make(map[string]jsonschema.Definition)
for k, v := range o.fields {
def.Properties[k] = v.Definition()
}
def.AdditionalProperties = false
return def
}
func (o object) FromAny(val any) (reflect.Value, error) {
// if the value is nil, we can't do anything
if val == nil {
return reflect.Value{}, nil
}
// now make a new object of the type we're trying to parse
obj := reflect.New(o.ref).Elem()
// now we need to iterate over the fields and set the values
for k, v := range o.fields {
// get the field by name
field := obj.FieldByName(k)
if !field.IsValid() {
return reflect.Value{}, errors.New("field " + k + " not found")
}
// get the value from the map
val2, ok := val.(map[string]interface{})[k]
if !ok {
return reflect.Value{}, errors.New("field " + k + " not found in map")
}
// now we need to convert the value to the correct type
val3, err := v.FromAny(val2)
if err != nil {
return reflect.Value{}, err
}
// now we need to set the value on the field
v.SetValueOnField(field, val3)
}
return obj, nil
}
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
if !o.required {
val = val.Addr()
}
obj.Field(o.index).Set(val)
}

18
schema/type.go Normal file
View File

@@ -0,0 +1,18 @@
package schema
import (
"reflect"
"github.com/sashabaranov/go-openai/jsonschema"
)
type Type interface {
SchemaType() jsonschema.DataType
Definition() jsonschema.Definition
Required() bool
Description() string
FromAny(any) (reflect.Value, error)
SetValueOnField(obj reflect.Value, val reflect.Value)
}

79
toolbox.go Normal file
View File

@@ -0,0 +1,79 @@
package go_llm
import (
"context"
"errors"
"fmt"
"github.com/sashabaranov/go-openai"
)
// ToolBox is a collection of tools that OpenAI can use to execute functions.
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
// the correct parameters.
type ToolBox struct {
funcs []Function
names map[string]Function
}
func NewToolBox(fns ...*Function) *ToolBox {
res := ToolBox{
funcs: []Function{},
names: map[string]Function{},
}
for _, f := range fns {
o := *f
res.names[o.Name] = o
res.funcs = append(res.funcs, o)
}
return &res
}
func (t *ToolBox) WithFunction(f Function) *ToolBox {
t2 := *t
t2.names[f.Name] = f
t2.funcs = append(t2.funcs, f)
return &t2
}
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
func (t *ToolBox) toOpenAI() []openai.Tool {
var res []openai.Tool
for _, f := range t.funcs {
res = append(res, openai.Tool{
Type: "function",
Function: f.toOpenAIFunction(),
})
}
return res
}
func (t *ToolBox) ToToolChoice() any {
if len(t.funcs) == 0 {
return nil
}
return "required"
}
var (
ErrFunctionNotFound = errors.New("function not found")
)
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
f, ok := t.names[functionName]
if !ok {
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
}
return f.Execute(ctx, params)
}
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
}