91 lines
1.9 KiB
Go
91 lines
1.9 KiB
Go
|
package go_llm
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"github.com/sashabaranov/go-openai"
|
||
|
"log/slog"
|
||
|
)
|
||
|
|
||
|
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
||
|
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
||
|
// the correct parameters.
|
||
|
type ToolBox struct {
|
||
|
funcs []Function
|
||
|
names map[string]Function
|
||
|
}
|
||
|
|
||
|
func NewToolBox(fns ...*Function) *ToolBox {
|
||
|
res := ToolBox{
|
||
|
funcs: []Function{},
|
||
|
names: map[string]Function{},
|
||
|
}
|
||
|
|
||
|
for _, f := range fns {
|
||
|
o := *f
|
||
|
res.names[o.Name] = o
|
||
|
res.funcs = append(res.funcs, o)
|
||
|
}
|
||
|
|
||
|
return &res
|
||
|
}
|
||
|
|
||
|
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
||
|
t2 := *t
|
||
|
t2.names[f.Name] = f
|
||
|
t2.funcs = append(t2.funcs, f)
|
||
|
|
||
|
return &t2
|
||
|
}
|
||
|
|
||
|
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
||
|
func (t *ToolBox) toOpenAI() []openai.Tool {
|
||
|
var res []openai.Tool
|
||
|
|
||
|
for _, f := range t.funcs {
|
||
|
res = append(res, openai.Tool{
|
||
|
Type: "function",
|
||
|
Function: f.toOpenAIFunction(),
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
func (t *ToolBox) ToToolChoice() any {
|
||
|
if len(t.funcs) == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return "required"
|
||
|
}
|
||
|
|
||
|
// ErrFunctionNotFound is the sentinel that ToolBox.ExecuteFunction passes to
// newError when no function is registered under the requested name.
// Whether callers can match it with errors.Is depends on how newError wraps
// it; verify before relying on that.
var ErrFunctionNotFound = errors.New("function not found")
|
||
|
|
||
|
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
|
||
|
f, ok := t.names[functionName]
|
||
|
|
||
|
slog.Info("functionName", functionName)
|
||
|
|
||
|
if !ok {
|
||
|
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
||
|
}
|
||
|
|
||
|
return f.Execute(ctx, params)
|
||
|
}
|
||
|
|
||
|
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
|
||
|
slog.Info("toolCall", toolCall)
|
||
|
|
||
|
b, err := json.Marshal(toolCall.FunctionCall.Arguments)
|
||
|
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("failed to marshal arguments: %w", err)
|
||
|
}
|
||
|
return t.ExecuteFunction(ctx, toolCall.ID, string(b))
|
||
|
}
|