Compare commits
No commits in common. "46a526fd5a16ea0f4d5f7e9ac2904bbcac12fca2" and "0b06fd965e229b5862977f55e6dedfa70659112d" have entirely different histories.
46a526fd5a
...
0b06fd965e
22
anthropic.go
22
anthropic.go
@ -2,7 +2,6 @@ package go_llm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -121,7 +120,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tool := range req.Toolbox.funcs {
|
for _, tool := range req.Toolbox {
|
||||||
res.Tools = append(res.Tools, anth.ToolDefinition{
|
res.Tools = append(res.Tools, anth.ToolDefinition{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: tool.Description,
|
Description: tool.Description,
|
||||||
@ -154,18 +153,13 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
|
|||||||
|
|
||||||
case anth.MessagesContentTypeToolUse:
|
case anth.MessagesContentTypeToolUse:
|
||||||
if msg.MessageContentToolUse != nil {
|
if msg.MessageContentToolUse != nil {
|
||||||
b, e := json.Marshal(msg.MessageContentToolUse.Input)
|
choice.Calls = append(choice.Calls, ToolCall{
|
||||||
if e != nil {
|
ID: msg.MessageContentToolUse.ID,
|
||||||
log.Println("failed to marshal input", e)
|
FunctionCall: FunctionCall{
|
||||||
} else {
|
Name: msg.MessageContentToolUse.Name,
|
||||||
choice.Calls = append(choice.Calls, ToolCall{
|
Arguments: msg.MessageContentToolUse.Input,
|
||||||
ID: msg.MessageContentToolUse.ID,
|
},
|
||||||
FunctionCall: FunctionCall{
|
})
|
||||||
Name: msg.MessageContentToolUse.Name,
|
|
||||||
Arguments: string(b),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
21
error.go
21
error.go
@ -1,21 +0,0 @@
|
|||||||
package go_llm
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error.
|
|
||||||
type Error struct {
|
|
||||||
error
|
|
||||||
|
|
||||||
Source error
|
|
||||||
Parameter error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newError(parent error, err error) Error {
|
|
||||||
e := fmt.Errorf("%w: %w", parent, err)
|
|
||||||
return Error{
|
|
||||||
error: e,
|
|
||||||
|
|
||||||
Source: parent,
|
|
||||||
Parameter: err,
|
|
||||||
}
|
|
||||||
}
|
|
89
function.go
89
function.go
@ -1,92 +1,13 @@
|
|||||||
package go_llm
|
package go_llm
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
|
||||||
"github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Strict bool `json:"strict,omitempty"`
|
Strict bool `json:"strict,omitempty"`
|
||||||
Parameters schema.Type `json:"parameters"`
|
Parameters any `json:"parameters"`
|
||||||
|
|
||||||
Forced bool `json:"forced,omitempty"`
|
|
||||||
|
|
||||||
// Timeout is the maximum time to wait for the function to complete
|
|
||||||
Timeout time.Duration `json:"-"`
|
|
||||||
|
|
||||||
// fn is the function to call, only set if this is constructed with NewFunction
|
|
||||||
fn reflect.Value
|
|
||||||
|
|
||||||
paramType reflect.Type
|
|
||||||
|
|
||||||
// definition is a cache of the openaiImpl jsonschema definition
|
|
||||||
definition *jsonschema.Definition
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Function) Execute(ctx context.Context, input string) (string, error) {
|
|
||||||
if !f.fn.IsValid() {
|
|
||||||
return "", fmt.Errorf("function %s is not implemented", f.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// first, we need to parse the input into the struct
|
|
||||||
p := reflect.New(f.paramType)
|
|
||||||
fmt.Println("Function.Execute", f.Name, "input:", input)
|
|
||||||
//m := map[string]any{}
|
|
||||||
err := json.Unmarshal([]byte(input), p.Interface())
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
|
|
||||||
}
|
|
||||||
|
|
||||||
// now we can call the function
|
|
||||||
exec := func(ctx context.Context) (string, error) {
|
|
||||||
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
|
||||||
|
|
||||||
if len(out) != 2 {
|
|
||||||
return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
if out[1].IsNil() {
|
|
||||||
return out[0].String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", out[1].Interface().(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if f.Timeout > 0 {
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, f.Timeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
return exec(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Function) toOpenAIFunction() *openai.FunctionDefinition {
|
|
||||||
return &openai.FunctionDefinition{
|
|
||||||
Name: f.Name,
|
|
||||||
Description: f.Description,
|
|
||||||
Strict: f.Strict,
|
|
||||||
Parameters: f.Parameters,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (f *Function) toOpenAIDefinition() jsonschema.Definition {
|
|
||||||
if f.definition == nil {
|
|
||||||
def := f.Parameters.Definition()
|
|
||||||
f.definition = &def
|
|
||||||
}
|
|
||||||
|
|
||||||
return *f.definition
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type FunctionCall struct {
|
type FunctionCall struct {
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Arguments string `json:"arguments,omitempty"`
|
Arguments any `json:"arguments,omitempty"`
|
||||||
}
|
}
|
||||||
|
35
functions.go
35
functions.go
@ -1,35 +0,0 @@
|
|||||||
package go_llm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Parse takes a function pointer and returns a function object.
|
|
||||||
// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains
|
|
||||||
// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to
|
|
||||||
// those types.
|
|
||||||
// The struct parameters can have the following tags:
|
|
||||||
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for
|
|
||||||
|
|
||||||
func NewFunction[T any](name string, description string, fn func(context.Context, T) (string, error)) *Function {
|
|
||||||
var o T
|
|
||||||
|
|
||||||
res := Function{
|
|
||||||
Name: name,
|
|
||||||
Description: description,
|
|
||||||
Parameters: schema.GetType(o),
|
|
||||||
fn: reflect.ValueOf(fn),
|
|
||||||
paramType: reflect.TypeOf(o),
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.fn.Kind() != reflect.Func {
|
|
||||||
panic("fn must be a function")
|
|
||||||
}
|
|
||||||
if res.paramType.Kind() != reflect.Struct {
|
|
||||||
panic("function parameter must be a struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &res
|
|
||||||
}
|
|
13
google.go
13
google.go
@ -2,7 +2,6 @@ package go_llm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/generative-ai-go/genai"
|
"github.com/google/generative-ai-go/genai"
|
||||||
"google.golang.org/api/option"
|
"google.golang.org/api/option"
|
||||||
@ -31,8 +30,8 @@ func (g google) requestToGoogleRequest(in Request, model *genai.GenerativeModel)
|
|||||||
res = append(res, genai.Text(c.Text))
|
res = append(res, genai.Text(c.Text))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tool := range in.Toolbox.funcs {
|
for _, tool := range in.Toolbox {
|
||||||
panic("google ToolBox is todo" + tool.Name)
|
panic("google toolbox is todo" + tool.Name)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
t := genai.Tool{}
|
t := genai.Tool{}
|
||||||
@ -64,17 +63,11 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
|
|||||||
choice := ResponseChoice{}
|
choice := ResponseChoice{}
|
||||||
|
|
||||||
choice.Content = v.Name
|
choice.Content = v.Name
|
||||||
b, e := json.Marshal(v.Args)
|
|
||||||
|
|
||||||
if e != nil {
|
|
||||||
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
|
||||||
}
|
|
||||||
|
|
||||||
call := ToolCall{
|
call := ToolCall{
|
||||||
ID: v.Name,
|
ID: v.Name,
|
||||||
FunctionCall: FunctionCall{
|
FunctionCall: FunctionCall{
|
||||||
Name: v.Name,
|
Name: v.Name,
|
||||||
Arguments: string(b),
|
Arguments: v.Args,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
4
llm.go
4
llm.go
@ -27,7 +27,7 @@ type Message struct {
|
|||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
Messages []Message
|
Messages []Message
|
||||||
Toolbox *ToolBox
|
Toolbox []Function
|
||||||
Temperature *float32
|
Temperature *float32
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ type LLM interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenAI(key string) LLM {
|
func OpenAI(key string) LLM {
|
||||||
return openaiImpl{key: key}
|
return openai{key: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Anthropic(key string) LLM {
|
func Anthropic(key string) LLM {
|
||||||
|
30
openai.go
30
openai.go
@ -7,14 +7,14 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type openaiImpl struct {
|
type openai struct {
|
||||||
key string
|
key string
|
||||||
model string
|
model string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LLM = openaiImpl{}
|
var _ LLM = openai{}
|
||||||
|
|
||||||
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
func (o openai) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
||||||
res := oai.ChatCompletionRequest{
|
res := oai.ChatCompletionRequest{
|
||||||
Model: o.model,
|
Model: o.model,
|
||||||
}
|
}
|
||||||
@ -57,18 +57,16 @@ func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRe
|
|||||||
res.Messages = append(res.Messages, m)
|
res.Messages = append(res.Messages, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tool := range request.Toolbox.funcs {
|
for _, tool := range request.Toolbox {
|
||||||
res.Tools = append(res.Tools, oai.Tool{
|
res.Tools = append(res.Tools, oai.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: &oai.FunctionDefinition{
|
Function: &oai.FunctionDefinition{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: tool.Description,
|
Description: tool.Description,
|
||||||
Strict: tool.Strict,
|
Strict: tool.Strict,
|
||||||
Parameters: tool.Parameters.Definition(),
|
Parameters: tool.Parameters,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Temperature != nil {
|
if request.Temperature != nil {
|
||||||
@ -92,13 +90,12 @@ func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRe
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
func (o openai) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
||||||
res := Response{}
|
res := Response{}
|
||||||
|
|
||||||
for _, choice := range response.Choices {
|
for _, choice := range response.Choices {
|
||||||
var toolCalls []ToolCall
|
var tools []ToolCall
|
||||||
for _, call := range choice.Message.ToolCalls {
|
for _, call := range choice.Message.ToolCalls {
|
||||||
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
|
|
||||||
toolCall := ToolCall{
|
toolCall := ToolCall{
|
||||||
ID: call.ID,
|
ID: call.ID,
|
||||||
FunctionCall: FunctionCall{
|
FunctionCall: FunctionCall{
|
||||||
@ -107,9 +104,7 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
|
tools = append(tools, toolCall)
|
||||||
|
|
||||||
toolCalls = append(toolCalls, toolCall)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
res.Choices = append(res.Choices, ResponseChoice{
|
res.Choices = append(res.Choices, ResponseChoice{
|
||||||
@ -117,14 +112,13 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
|
|||||||
Role: Role(choice.Message.Role),
|
Role: Role(choice.Message.Role),
|
||||||
Name: choice.Message.Name,
|
Name: choice.Message.Name,
|
||||||
Refusal: choice.Message.Refusal,
|
Refusal: choice.Message.Refusal,
|
||||||
Calls: toolCalls,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
func (o openai) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
||||||
cl := oai.NewClient(o.key)
|
cl := oai.NewClient(o.key)
|
||||||
|
|
||||||
req := o.requestToOpenAIRequest(request)
|
req := o.requestToOpenAIRequest(request)
|
||||||
@ -134,14 +128,14 @@ func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response
|
|||||||
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
return Response{}, fmt.Errorf("unhandled openai error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return o.responseToLLMResponse(resp), nil
|
return o.responseToLLMResponse(resp), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
func (o openai) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||||
return openaiImpl{
|
return openai{
|
||||||
key: o.key,
|
key: o.key,
|
||||||
model: modelVersion,
|
model: modelVersion,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -1,125 +0,0 @@
|
|||||||
package schema
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
|
|
||||||
// can be used to generate a json schema and build an object from a parsed json object.
|
|
||||||
func GetType(a any) Type {
|
|
||||||
t := reflect.TypeOf(a)
|
|
||||||
|
|
||||||
if t.Kind() != reflect.Struct {
|
|
||||||
panic("GetType expects a struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
return getObject(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFromType(t reflect.Type, b basic) Type {
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
b.required = false
|
|
||||||
}
|
|
||||||
|
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
b.DataType = jsonschema.String
|
|
||||||
return b
|
|
||||||
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
||||||
b.DataType = jsonschema.Integer
|
|
||||||
return b
|
|
||||||
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
b.DataType = jsonschema.Integer
|
|
||||||
return b
|
|
||||||
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
b.DataType = jsonschema.Number
|
|
||||||
return b
|
|
||||||
|
|
||||||
case reflect.Bool:
|
|
||||||
b.DataType = jsonschema.Boolean
|
|
||||||
return b
|
|
||||||
|
|
||||||
case reflect.Struct:
|
|
||||||
o := getObject(t)
|
|
||||||
|
|
||||||
o.basic.required = b.required
|
|
||||||
o.basic.index = b.index
|
|
||||||
o.basic.description = b.description
|
|
||||||
|
|
||||||
return o
|
|
||||||
|
|
||||||
case reflect.Slice:
|
|
||||||
return getArray(t)
|
|
||||||
|
|
||||||
default:
|
|
||||||
panic("unhandled default case for " + t.Kind().String() + " in getFromType")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getField(f reflect.StructField, index int) Type {
|
|
||||||
b := basic{
|
|
||||||
index: index,
|
|
||||||
required: true,
|
|
||||||
description: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
t := f.Type
|
|
||||||
|
|
||||||
// if the tag "description" is set, use that as the description
|
|
||||||
if desc, ok := f.Tag.Lookup("description"); ok {
|
|
||||||
b.description = desc
|
|
||||||
}
|
|
||||||
|
|
||||||
// now if the tag "enum" is set, we need to create an enum type
|
|
||||||
if v, ok := f.Tag.Lookup("enum"); ok {
|
|
||||||
vals := strings.Split(v, ",")
|
|
||||||
|
|
||||||
for i := 0; i < len(vals); i++ {
|
|
||||||
vals[i] = strings.TrimSpace(vals[i])
|
|
||||||
|
|
||||||
if vals[i] == "" {
|
|
||||||
vals = append(vals[:i], vals[i+1:]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return enum{
|
|
||||||
basic: b,
|
|
||||||
values: vals,
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return getFromType(t, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getObject(t reflect.Type) object {
|
|
||||||
fields := make(map[string]Type, t.NumField())
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
field := t.Field(i)
|
|
||||||
fields[field.Name] = getField(field, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
return object{
|
|
||||||
basic: basic{DataType: jsonschema.Object},
|
|
||||||
fields: fields,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getArray(t reflect.Type) array {
|
|
||||||
res := array{
|
|
||||||
basic: basic{
|
|
||||||
DataType: jsonschema.Array,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
res.items = getFromType(t.Elem(), basic{})
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
@ -1,65 +0,0 @@
|
|||||||
package schema
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
type array struct {
|
|
||||||
basic
|
|
||||||
|
|
||||||
// items is the schema of the items in the array
|
|
||||||
items Type
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a array) SchemaType() jsonschema.DataType {
|
|
||||||
return jsonschema.Array
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a array) Definition() jsonschema.Definition {
|
|
||||||
def := a.basic.Definition()
|
|
||||||
def.Type = jsonschema.Array
|
|
||||||
i := a.items.Definition()
|
|
||||||
def.Items = &i
|
|
||||||
def.AdditionalProperties = false
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a array) FromAny(val any) (reflect.Value, error) {
|
|
||||||
v := reflect.ValueOf(val)
|
|
||||||
|
|
||||||
// first realize we may have a pointer to a slice if this type is not required
|
|
||||||
if !a.required && v.Kind() == reflect.Ptr {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if v.Kind() != reflect.Slice {
|
|
||||||
return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the slice is nil, we can just return it
|
|
||||||
if v.IsNil() {
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the slice is not nil, we need to convert each item
|
|
||||||
items := make([]reflect.Value, v.Len())
|
|
||||||
for i := 0; i < v.Len(); i++ {
|
|
||||||
item, err := a.items.FromAny(v.Index(i).Interface())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, err
|
|
||||||
}
|
|
||||||
items[i] = item
|
|
||||||
}
|
|
||||||
|
|
||||||
return reflect.ValueOf(items), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a array) SetValue(obj reflect.Value, val reflect.Value) {
|
|
||||||
if !a.required {
|
|
||||||
val = val.Addr()
|
|
||||||
}
|
|
||||||
obj.Field(a.index).Set(val)
|
|
||||||
}
|
|
105
schema/basic.go
105
schema/basic.go
@ -1,105 +0,0 @@
|
|||||||
package schema
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
// just enforcing that basic implements Type
|
|
||||||
var _ Type = basic{}
|
|
||||||
|
|
||||||
type basic struct {
|
|
||||||
jsonschema.DataType
|
|
||||||
|
|
||||||
// index is the position of the parameter in the StructField of the function's parameter struct
|
|
||||||
index int
|
|
||||||
|
|
||||||
// required is a flag that indicates whether the parameter is required in the function's parameter struct.
|
|
||||||
// this is inferred by if the parameter is a pointer type or not.
|
|
||||||
required bool
|
|
||||||
|
|
||||||
// description is a llm-readable description of the parameter passed to openai
|
|
||||||
description string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b basic) SchemaType() jsonschema.DataType {
|
|
||||||
return b.DataType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b basic) Definition() jsonschema.Definition {
|
|
||||||
return jsonschema.Definition{
|
|
||||||
Type: b.DataType,
|
|
||||||
Description: b.description,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b basic) Required() bool {
|
|
||||||
return b.required
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b basic) Description() string {
|
|
||||||
return b.description
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b basic) FromAny(val any) (reflect.Value, error) {
|
|
||||||
v := reflect.ValueOf(val)
|
|
||||||
|
|
||||||
switch b.DataType {
|
|
||||||
case jsonschema.String:
|
|
||||||
var val = v.String()
|
|
||||||
|
|
||||||
return reflect.ValueOf(val), nil
|
|
||||||
|
|
||||||
case jsonschema.Integer:
|
|
||||||
if v.Kind() == reflect.Float64 {
|
|
||||||
return v.Convert(reflect.TypeOf(int(0))), nil
|
|
||||||
} else if v.Kind() != reflect.Int {
|
|
||||||
return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
|
|
||||||
} else {
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case jsonschema.Number:
|
|
||||||
if v.Kind() == reflect.Float64 {
|
|
||||||
return v.Convert(reflect.TypeOf(float64(0))), nil
|
|
||||||
} else if v.Kind() != reflect.Float64 {
|
|
||||||
return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
|
|
||||||
} else {
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case jsonschema.Boolean:
|
|
||||||
if v.Kind() == reflect.Bool {
|
|
||||||
return v, nil
|
|
||||||
} else if v.Kind() == reflect.String {
|
|
||||||
b, err := strconv.ParseBool(v.String())
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
|
||||||
}
|
|
||||||
return reflect.ValueOf(b), nil
|
|
||||||
} else {
|
|
||||||
return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return reflect.Value{}, errors.New("unknown type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
|
||||||
// if this basic type is not required that means it's a pointer type
|
|
||||||
// so we need to create a new value of the type of the pointer
|
|
||||||
if !b.required {
|
|
||||||
vv := reflect.New(obj.Field(b.index).Type().Elem())
|
|
||||||
|
|
||||||
// and then set the value of the pointer to the new value
|
|
||||||
vv.Elem().Set(val)
|
|
||||||
|
|
||||||
obj.Field(b.index).Set(vv)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
obj.Field(b.index).Set(val)
|
|
||||||
}
|
|
@ -1,47 +0,0 @@
|
|||||||
package schema
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
type enum struct {
|
|
||||||
basic
|
|
||||||
|
|
||||||
values []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e enum) SchemaType() jsonschema.DataType {
|
|
||||||
return jsonschema.String
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e enum) Definition() jsonschema.Definition {
|
|
||||||
def := e.basic.Definition()
|
|
||||||
def.Enum = e.values
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e enum) FromAny(val any) (reflect.Value, error) {
|
|
||||||
v := reflect.ValueOf(val)
|
|
||||||
if v.Kind() != reflect.String {
|
|
||||||
return reflect.Value{}, errors.New("expected string, got " + v.Kind().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
s := v.String()
|
|
||||||
if !slices.Contains(e.values, s) {
|
|
||||||
return reflect.Value{}, errors.New("value " + s + " not in enum")
|
|
||||||
}
|
|
||||||
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
|
||||||
if !e.required {
|
|
||||||
val = val.Addr()
|
|
||||||
}
|
|
||||||
obj.Field(e.index).Set(val)
|
|
||||||
}
|
|
@ -1,78 +0,0 @@
|
|||||||
package schema
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
type object struct {
|
|
||||||
basic
|
|
||||||
|
|
||||||
ref reflect.Type
|
|
||||||
|
|
||||||
fields map[string]Type
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o object) SchemaType() jsonschema.DataType {
|
|
||||||
return jsonschema.Object
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o object) Definition() jsonschema.Definition {
|
|
||||||
def := o.basic.Definition()
|
|
||||||
def.Type = jsonschema.Object
|
|
||||||
def.Properties = make(map[string]jsonschema.Definition)
|
|
||||||
for k, v := range o.fields {
|
|
||||||
def.Properties[k] = v.Definition()
|
|
||||||
}
|
|
||||||
|
|
||||||
def.AdditionalProperties = false
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o object) FromAny(val any) (reflect.Value, error) {
|
|
||||||
// if the value is nil, we can't do anything
|
|
||||||
if val == nil {
|
|
||||||
return reflect.Value{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// now make a new object of the type we're trying to parse
|
|
||||||
obj := reflect.New(o.ref).Elem()
|
|
||||||
|
|
||||||
// now we need to iterate over the fields and set the values
|
|
||||||
for k, v := range o.fields {
|
|
||||||
// get the field by name
|
|
||||||
field := obj.FieldByName(k)
|
|
||||||
if !field.IsValid() {
|
|
||||||
return reflect.Value{}, errors.New("field " + k + " not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the value from the map
|
|
||||||
val2, ok := val.(map[string]interface{})[k]
|
|
||||||
if !ok {
|
|
||||||
return reflect.Value{}, errors.New("field " + k + " not found in map")
|
|
||||||
}
|
|
||||||
|
|
||||||
// now we need to convert the value to the correct type
|
|
||||||
val3, err := v.FromAny(val2)
|
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// now we need to set the value on the field
|
|
||||||
v.SetValueOnField(field, val3)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return obj, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o object) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
|
||||||
// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
|
|
||||||
if !o.required {
|
|
||||||
val = val.Addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
obj.Field(o.index).Set(val)
|
|
||||||
}
|
|
@ -1,18 +0,0 @@
|
|||||||
package schema
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Type interface {
|
|
||||||
SchemaType() jsonschema.DataType
|
|
||||||
Definition() jsonschema.Definition
|
|
||||||
|
|
||||||
Required() bool
|
|
||||||
Description() string
|
|
||||||
|
|
||||||
FromAny(any) (reflect.Value, error)
|
|
||||||
SetValueOnField(obj reflect.Value, val reflect.Value)
|
|
||||||
}
|
|
79
toolbox.go
79
toolbox.go
@ -1,79 +0,0 @@
|
|||||||
package go_llm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/sashabaranov/go-openai"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
|
||||||
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
|
||||||
// the correct parameters.
|
|
||||||
type ToolBox struct {
|
|
||||||
funcs []Function
|
|
||||||
names map[string]Function
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewToolBox(fns ...*Function) *ToolBox {
|
|
||||||
res := ToolBox{
|
|
||||||
funcs: []Function{},
|
|
||||||
names: map[string]Function{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, f := range fns {
|
|
||||||
o := *f
|
|
||||||
res.names[o.Name] = o
|
|
||||||
res.funcs = append(res.funcs, o)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
|
||||||
t2 := *t
|
|
||||||
t2.names[f.Name] = f
|
|
||||||
t2.funcs = append(t2.funcs, f)
|
|
||||||
|
|
||||||
return &t2
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
|
||||||
func (t *ToolBox) toOpenAI() []openai.Tool {
|
|
||||||
var res []openai.Tool
|
|
||||||
|
|
||||||
for _, f := range t.funcs {
|
|
||||||
res = append(res, openai.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: f.toOpenAIFunction(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *ToolBox) ToToolChoice() any {
|
|
||||||
if len(t.funcs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return "required"
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrFunctionNotFound = errors.New("function not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
|
|
||||||
f, ok := t.names[functionName]
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Execute(ctx, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
|
|
||||||
return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user