Introduced `WithFunctionRemoved` and `ExecuteCallbacks` methods to enhance `ToolBox` functionality. This allows dynamic function removal and execution of custom callbacks during tool call processing. Also cleaned up logging and improved handling for required tools in `openai.go`.
143 lines
3.5 KiB
Go
143 lines
3.5 KiB
Go
package go_llm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
// ToolBox is a collection of tools that OpenAI can use to execute functions.
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
// the correct parameters.
//
// Build one with NewToolBox; the With* methods return a new *ToolBox rather
// than the receiver.
type ToolBox struct {
	// funcs holds the registered functions in insertion order; toOpenAI
	// converts them, in this order, into the tool list for the API request.
	funcs []Function

	// names indexes the registered functions by Function.Name; executeFunction
	// dispatches a tool call by looking up the requested name here.
	names map[string]Function

	// dontRequireTool is the negation of the value passed to WithRequireTool,
	// so the zero value (false) corresponds to "tool required".
	dontRequireTool bool
}
|
|
|
|
func NewToolBox(fns ...*Function) *ToolBox {
|
|
res := ToolBox{
|
|
funcs: []Function{},
|
|
names: map[string]Function{},
|
|
}
|
|
|
|
for _, f := range fns {
|
|
o := *f
|
|
res.names[o.Name] = o
|
|
res.funcs = append(res.funcs, o)
|
|
}
|
|
|
|
return &res
|
|
}
|
|
|
|
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
|
t2 := *t
|
|
t2.names[f.Name] = f
|
|
t2.funcs = append(t2.funcs, f)
|
|
|
|
return &t2
|
|
}
|
|
|
|
func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox {
|
|
t2 := *t
|
|
|
|
delete(t2.names, name)
|
|
|
|
for i, f := range t2.funcs {
|
|
if f.Name == name {
|
|
t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
|
|
return &t2
|
|
}
|
|
|
|
func (t *ToolBox) WithRequireTool(val bool) *ToolBox {
|
|
t2 := *t
|
|
t2.dontRequireTool = !val
|
|
return &t2
|
|
}
|
|
|
|
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
|
func (t *ToolBox) toOpenAI() []openai.Tool {
|
|
var res []openai.Tool
|
|
|
|
for _, f := range t.funcs {
|
|
res = append(res, openai.Tool{
|
|
Type: "function",
|
|
Function: f.toOpenAIFunction(),
|
|
})
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (t *ToolBox) ToToolChoice() any {
|
|
if len(t.funcs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return "required"
|
|
}
|
|
|
|
// ErrFunctionNotFound is the base error used when a tool call names a function
// that is not registered in the ToolBox. ExecuteCallbacks also uses it when a
// tool call has an empty function name or when one of its callbacks fails.
var ErrFunctionNotFound = errors.New("function not found")
|
|
|
|
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (string, error) {
|
|
f, ok := t.names[functionName]
|
|
|
|
if !ok {
|
|
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
|
}
|
|
|
|
return f.Execute(ctx, params)
|
|
}
|
|
|
|
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) {
|
|
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
|
}
|
|
|
|
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
|
|
// OnNewFunction is called when a new function is created
|
|
// OnFunctionFinished is called when a function is finished
|
|
func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result string, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
|
|
var res []ToolCallResponse
|
|
|
|
for _, call := range toolCalls {
|
|
ctx := ctx.WithToolCall(&call)
|
|
if call.FunctionCall.Name == "" {
|
|
return nil, newError(ErrFunctionNotFound, errors.New("function name is empty"))
|
|
}
|
|
|
|
var arg any
|
|
if OnNewFunction != nil {
|
|
var err error
|
|
arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments)
|
|
|
|
if err != nil {
|
|
return nil, newError(ErrFunctionNotFound, err)
|
|
}
|
|
}
|
|
out, err := t.Execute(ctx, call)
|
|
|
|
if OnFunctionFinished != nil {
|
|
err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg)
|
|
if err != nil {
|
|
return nil, newError(ErrFunctionNotFound, err)
|
|
}
|
|
}
|
|
|
|
res = append(res, ToolCallResponse{
|
|
ID: call.ID,
|
|
Result: out,
|
|
Error: err,
|
|
})
|
|
}
|
|
|
|
return res, nil
|
|
}
|