go-llm/toolbox.go

161 lines
3.8 KiB
Go

package go_llm
import (
	"context"
	"errors"
	"fmt"
	"sort"
)
// ToolBox is a collection of tools (Functions) that the model can call, keyed by function name.
// It wraps the set of functions and dispatches a tool call to the function with the matching name, passing the
// call's arguments through (see Execute and ExecuteCallbacks). Use NewToolBox to construct one.
//
// Methods use value receivers, but the function map is shared by every copy of a ToolBox: WithFunction,
// WithFunctions, WithFunctionRemoved and WithSyntheticFieldsAddedToAllFunctions modify that shared map.
type ToolBox struct {
	// functions holds the registered tools keyed by Function.Name; registering another function with the same
	// name replaces the earlier one.
	functions map[string]Function
	// dontRequireTool is stored inverted so the zero value means "a tool call is required"; it is set by
	// WithRequireTool and read by RequiresTool.
	dontRequireTool bool
}
// NewToolBox builds a ToolBox holding the given functions, keyed by their Name.
// When several functions share a name, the one appearing last in fns wins.
// Tool use is required by default (see WithRequireTool / RequiresTool).
func NewToolBox(fns ...Function) ToolBox {
	byName := make(map[string]Function, len(fns))
	for i := range fns {
		byName[fns[i].Name] = fns[i]
	}
	return ToolBox{functions: byName}
}
// Functions returns every registered function, ordered by the name it is registered under.
//
// The functions live in a map, whose iteration order Go deliberately randomizes; sorting makes the result (and
// any request built from it) identical from call to call. An empty ToolBox yields a nil slice.
func (t ToolBox) Functions() []Function {
	if len(t.functions) == 0 {
		return nil
	}

	names := make([]string, 0, len(t.functions))
	for name := range t.functions {
		names = append(names, name)
	}
	// Map keys are unique, so this ordering is total and therefore deterministic.
	sort.Strings(names)

	res := make([]Function, 0, len(names))
	for _, name := range names {
		res = append(res, t.functions[name])
	}
	return res
}
// WithFunction registers f under f.Name, replacing any function already registered with that name, and returns
// the ToolBox.
//
// If the ToolBox already has a function map, that map is shared by every copy of the ToolBox, so the change is
// also visible through the value WithFunction was called on. A zero-value ToolBox (nil map) gets a fresh map
// instead of panicking on the write; in that case only the returned ToolBox contains f.
func (t ToolBox) WithFunction(f Function) ToolBox {
	if t.functions == nil {
		t.functions = map[string]Function{}
	}
	t.functions[f.Name] = f
	return t
}
// WithFunctions registers each function in fns under its Name, in order, so a later function replaces an
// earlier one with the same name. It returns the ToolBox.
//
// If the ToolBox already has a function map, that map is shared by every copy of the ToolBox, so the change is
// also visible through the value WithFunctions was called on. A zero-value ToolBox (nil map) gets a fresh map
// instead of panicking on the write; in that case only the returned ToolBox contains fns.
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
	if t.functions == nil {
		t.functions = make(map[string]Function, len(fns))
	}
	for _, f := range fns {
		t.functions[f.Name] = f
	}
	return t
}
// WithSyntheticFieldsAddedToAllFunctions replaces every registered function with the result of calling its
// WithSyntheticFields method with fieldsAndDescriptions, then returns the ToolBox.
// Presumably fieldsAndDescriptions maps field name -> description (per the parameter name); confirm against
// Function.WithSyntheticFields.
// The replacement is written into the function map shared by all copies of this ToolBox.
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
	for name := range t.functions {
		// Overwriting an existing key while ranging is allowed; no keys are added or removed.
		t.functions[name] = t.functions[name].WithSyntheticFields(fieldsAndDescriptions)
	}
	return t
}
// ForEachFunction calls fn once for every registered function.
// The visiting order follows Go map iteration and is therefore unspecified.
func (t ToolBox) ForEachFunction(fn func(f Function)) {
	for name := range t.functions {
		fn(t.functions[name])
	}
}
// WithFunctionRemoved unregisters the function called name (a no-op when there is none) and returns the
// ToolBox. The removal applies to the function map shared by every copy of this ToolBox.
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
	if _, present := t.functions[name]; present {
		delete(t.functions, name)
	}
	return t
}
// WithRequireTool returns a copy of the ToolBox whose RequiresTool reports val whenever at least one function
// is registered. Only the returned copy is affected; the receiver keeps its previous setting.
func (t ToolBox) WithRequireTool(val bool) ToolBox {
	updated := t
	updated.dontRequireTool = !val
	return updated
}
// RequiresTool reports whether the model must call one of the tools. It is false for an empty ToolBox, and
// otherwise true unless tool use was made optional with WithRequireTool(false).
func (t ToolBox) RequiresTool() bool {
	if len(t.functions) == 0 {
		return false
	}
	return !t.dontRequireTool
}
// ToToolChoice returns the tool-choice value for a request built from this ToolBox.
// It returns "required" when RequiresTool reports true. Otherwise it returns nil, meaning no tool choice is
// forced; this covers both an empty ToolBox and one made optional with WithRequireTool(false).
func (t ToolBox) ToToolChoice() any {
	// Delegate to RequiresTool so WithRequireTool(false) is honored here too, not just the empty case.
	if !t.RequiresTool() {
		return nil
	}
	return "required"
}
// ErrFunctionNotFound is the sentinel error used when a tool call names a function that is not registered,
// or has an empty name. ExecuteCallbacks also uses it when wrapping errors returned by its callbacks.
var ErrFunctionNotFound = errors.New("function not found")
// executeFunction runs the registered function called functionName, passing it ctx and the raw params string
// (the tool call's arguments). When no such function is registered it returns "" and an error built by
// newError from ErrFunctionNotFound plus a message naming the missing function.
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
	if fn, found := t.functions[functionName]; found {
		return fn.Execute(ctx, params)
	}
	return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
}
// Execute runs a single tool call. It derives a context carrying the call (via ctx.WithToolCall), then
// dispatches to the function named in toolCall.FunctionCall, passing its Arguments unchanged.
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
	callCtx := ctx.WithToolCall(&toolCall)
	requested := toolCall.FunctionCall
	return t.executeFunction(callCtx, requested.Name, requested.Arguments)
}
// GetSyntheticParametersFromFunctionContext returns the map[string]string stored in ctx under the
// "syntheticParameters" key. It returns nil when the key is absent or holds a value of any other type.
//
// NOTE(review): the key is a plain string, which staticcheck (SA1029) flags as collision-prone. It must match
// whatever code stores the value, so it is left unchanged here.
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
	// A failed comma-ok assertion (including on a nil interface) yields the zero value, a nil map.
	params, _ := ctx.Value("syntheticParameters").(map[string]string)
	return params
}
// ExecuteCallbacks runs every tool call in toolCalls, in order, and returns one ToolCallResponse per call.
//
// For each call:
//   - a call with an empty function name aborts the batch with an error built from ErrFunctionNotFound;
//   - OnNewFunction, if non-nil, is invoked before the function runs. A non-nil error from it aborts the batch;
//     otherwise its result is later passed to OnFunctionFinished as newFunctionResult;
//   - the function is run via Execute, and its result and error are recorded in the response (a function error
//     does not abort the batch);
//   - OnFunctionFinished, if non-nil, is invoked with the function's result and error. A non-nil error it
//     returns aborts the batch.
//
// When the batch is aborted the returned slice is nil, even if earlier calls already ran.
// NOTE(review): callback errors are also wrapped with ErrFunctionNotFound, which misdescribes them; this is
// kept as-is because callers may test for that sentinel.
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
	var res []ToolCallResponse
	for i := range toolCalls {
		// Take a per-iteration copy: the derived context keeps a pointer to the ToolCall. Before Go 1.22 a range
		// variable was reused across iterations, so a retained context would later observe a different call.
		call := toolCalls[i]
		callCtx := ctx.WithToolCall(&call)

		name, args := call.FunctionCall.Name, call.FunctionCall.Arguments
		if name == "" {
			return nil, newError(ErrFunctionNotFound, errors.New("function name is empty"))
		}

		var newFunctionResult any
		if OnNewFunction != nil {
			var err error
			newFunctionResult, err = OnNewFunction(callCtx, name, args)
			if err != nil {
				return nil, newError(ErrFunctionNotFound, err)
			}
		}

		out, execErr := t.Execute(callCtx, call)

		if OnFunctionFinished != nil {
			if cbErr := OnFunctionFinished(callCtx, name, args, out, execErr, newFunctionResult); cbErr != nil {
				return nil, newError(ErrFunctionNotFound, cbErr)
			}
		}

		res = append(res, ToolCallResponse{
			ID:     call.ID,
			Result: out,
			Error:  execErr,
		})
	}
	return res, nil
}