initial commit of untested function stuff
This commit is contained in:
90
toolbox.go
Normal file
90
toolbox.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package go_llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
||||
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
||||
// the correct parameters.
|
||||
type ToolBox struct {
|
||||
funcs []Function
|
||||
names map[string]Function
|
||||
}
|
||||
|
||||
func NewToolBox(fns ...*Function) *ToolBox {
|
||||
res := ToolBox{
|
||||
funcs: []Function{},
|
||||
names: map[string]Function{},
|
||||
}
|
||||
|
||||
for _, f := range fns {
|
||||
o := *f
|
||||
res.names[o.Name] = o
|
||||
res.funcs = append(res.funcs, o)
|
||||
}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
||||
t2 := *t
|
||||
t2.names[f.Name] = f
|
||||
t2.funcs = append(t2.funcs, f)
|
||||
|
||||
return &t2
|
||||
}
|
||||
|
||||
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
||||
func (t *ToolBox) toOpenAI() []openai.Tool {
|
||||
var res []openai.Tool
|
||||
|
||||
for _, f := range t.funcs {
|
||||
res = append(res, openai.Tool{
|
||||
Type: "function",
|
||||
Function: f.toOpenAIFunction(),
|
||||
})
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t *ToolBox) ToToolChoice() any {
|
||||
if len(t.funcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return "required"
|
||||
}
|
||||
|
||||
var (
|
||||
ErrFunctionNotFound = errors.New("function not found")
|
||||
)
|
||||
|
||||
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
|
||||
f, ok := t.names[functionName]
|
||||
|
||||
slog.Info("functionName", functionName)
|
||||
|
||||
if !ok {
|
||||
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
||||
}
|
||||
|
||||
return f.Execute(ctx, params)
|
||||
}
|
||||
|
||||
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
|
||||
slog.Info("toolCall", toolCall)
|
||||
|
||||
b, err := json.Marshal(toolCall.FunctionCall.Arguments)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal arguments: %w", err)
|
||||
}
|
||||
return t.ExecuteFunction(ctx, toolCall.ID, string(b))
|
||||
}
|
Reference in New Issue
Block a user