161 lines
3.8 KiB
Go
161 lines
3.8 KiB
Go
package go_llm
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"sort"
)
|
|
|
|
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
|
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
|
// the correct parameters.
|
|
type ToolBox struct {
|
|
functions map[string]Function
|
|
dontRequireTool bool
|
|
}
|
|
|
|
func NewToolBox(fns ...Function) ToolBox {
|
|
res := ToolBox{
|
|
functions: map[string]Function{},
|
|
}
|
|
|
|
for _, f := range fns {
|
|
res.functions[f.Name] = f
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (t ToolBox) Functions() []Function {
|
|
var res []Function
|
|
|
|
for _, f := range t.functions {
|
|
res = append(res, f)
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (t ToolBox) WithFunction(f Function) ToolBox {
|
|
t.functions[f.Name] = f
|
|
|
|
return t
|
|
}
|
|
|
|
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
|
|
for _, f := range fns {
|
|
t.functions[f.Name] = f
|
|
}
|
|
|
|
return t
|
|
}
|
|
|
|
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
|
|
for k, v := range t.functions {
|
|
t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
|
|
}
|
|
|
|
return t
|
|
}
|
|
|
|
func (t ToolBox) ForEachFunction(fn func(f Function)) {
|
|
for _, f := range t.functions {
|
|
fn(f)
|
|
}
|
|
}
|
|
|
|
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
|
|
delete(t.functions, name)
|
|
return t
|
|
}
|
|
|
|
func (t ToolBox) WithRequireTool(val bool) ToolBox {
|
|
t.dontRequireTool = !val
|
|
return t
|
|
}
|
|
|
|
func (t ToolBox) RequiresTool() bool {
|
|
return !t.dontRequireTool && len(t.functions) > 0
|
|
}
|
|
|
|
func (t ToolBox) ToToolChoice() any {
|
|
if len(t.functions) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return "required"
|
|
}
|
|
|
|
// ErrFunctionNotFound is the sentinel error passed to newError for tool-dispatch
// failures in this toolbox, most notably a tool call naming a function that is
// not registered.
var ErrFunctionNotFound = errors.New("function not found")
|
|
|
|
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
|
f, ok := t.functions[functionName]
|
|
|
|
if !ok {
|
|
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
|
}
|
|
|
|
return f.Execute(ctx, params)
|
|
}
|
|
|
|
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
|
|
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
|
}
|
|
|
|
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
|
|
val := ctx.Value("syntheticParameters")
|
|
|
|
if val == nil {
|
|
return nil
|
|
}
|
|
|
|
syntheticParameters, ok := val.(map[string]string)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
return syntheticParameters
|
|
}
|
|
|
|
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
|
|
// OnNewFunction is called when a new function is created
|
|
// OnFunctionFinished is called when a function is finished
|
|
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
|
|
var res []ToolCallResponse
|
|
|
|
for _, call := range toolCalls {
|
|
ctx := ctx.WithToolCall(&call)
|
|
if call.FunctionCall.Name == "" {
|
|
return nil, newError(ErrFunctionNotFound, errors.New("function name is empty"))
|
|
}
|
|
|
|
var arg any
|
|
if OnNewFunction != nil {
|
|
var err error
|
|
arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments)
|
|
|
|
if err != nil {
|
|
return nil, newError(ErrFunctionNotFound, err)
|
|
}
|
|
}
|
|
out, err := t.Execute(ctx, call)
|
|
|
|
if OnFunctionFinished != nil {
|
|
err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg)
|
|
if err != nil {
|
|
return nil, newError(ErrFunctionNotFound, err)
|
|
}
|
|
}
|
|
|
|
res = append(res, ToolCallResponse{
|
|
ID: call.ID,
|
|
Result: out,
|
|
Error: err,
|
|
})
|
|
}
|
|
|
|
return res, nil
|
|
}
|