Refactor toolbox and function handling to support synthetic fields and improve type definitions
This commit is contained in:
136
toolbox.go
136
toolbox.go
@@ -4,79 +4,82 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// ToolBox is a collection of tools that OpenAI can use to execute functions.
|
||||
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
|
||||
// the correct parameters.
|
||||
type ToolBox struct {
|
||||
funcs []Function
|
||||
names map[string]Function
|
||||
functions map[string]Function
|
||||
dontRequireTool bool
|
||||
}
|
||||
|
||||
func NewToolBox(fns ...*Function) *ToolBox {
|
||||
func NewToolBox(fns ...Function) ToolBox {
|
||||
res := ToolBox{
|
||||
funcs: []Function{},
|
||||
names: map[string]Function{},
|
||||
functions: map[string]Function{},
|
||||
}
|
||||
|
||||
for _, f := range fns {
|
||||
o := *f
|
||||
res.names[o.Name] = o
|
||||
res.funcs = append(res.funcs, o)
|
||||
}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithFunction(f Function) *ToolBox {
|
||||
t2 := *t
|
||||
t2.names[f.Name] = f
|
||||
t2.funcs = append(t2.funcs, f)
|
||||
|
||||
return &t2
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox {
|
||||
t2 := *t
|
||||
|
||||
delete(t2.names, name)
|
||||
|
||||
for i, f := range t2.funcs {
|
||||
if f.Name == name {
|
||||
t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &t2
|
||||
}
|
||||
|
||||
func (t *ToolBox) WithRequireTool(val bool) *ToolBox {
|
||||
t2 := *t
|
||||
t2.dontRequireTool = !val
|
||||
return &t2
|
||||
}
|
||||
|
||||
// ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API.
|
||||
func (t *ToolBox) toOpenAI() []openai.Tool {
|
||||
var res []openai.Tool
|
||||
|
||||
for _, f := range t.funcs {
|
||||
res = append(res, openai.Tool{
|
||||
Type: "function",
|
||||
Function: f.toOpenAIFunction(),
|
||||
})
|
||||
res.functions[f.Name] = f
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t *ToolBox) ToToolChoice() any {
|
||||
if len(t.funcs) == 0 {
|
||||
func (t ToolBox) Functions() []Function {
|
||||
var res []Function
|
||||
|
||||
for _, f := range t.functions {
|
||||
res = append(res, f)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (t ToolBox) WithFunction(f Function) ToolBox {
|
||||
t.functions[f.Name] = f
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
|
||||
for _, f := range fns {
|
||||
t.functions[f.Name] = f
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
|
||||
for k, v := range t.functions {
|
||||
t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) ForEachFunction(fn func(f Function)) {
|
||||
for _, f := range t.functions {
|
||||
fn(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
|
||||
delete(t.functions, name)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) WithRequireTool(val bool) ToolBox {
|
||||
t.dontRequireTool = !val
|
||||
return t
|
||||
}
|
||||
|
||||
func (t ToolBox) RequiresTool() bool {
|
||||
return !t.dontRequireTool && len(t.functions) > 0
|
||||
}
|
||||
|
||||
func (t ToolBox) ToToolChoice() any {
|
||||
if len(t.functions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -87,8 +90,8 @@ var (
|
||||
ErrFunctionNotFound = errors.New("function not found")
|
||||
)
|
||||
|
||||
func (t *ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
||||
f, ok := t.names[functionName]
|
||||
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
||||
f, ok := t.functions[functionName]
|
||||
|
||||
if !ok {
|
||||
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
|
||||
@@ -97,14 +100,29 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri
|
||||
return f.Execute(ctx, params)
|
||||
}
|
||||
|
||||
func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
|
||||
func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
|
||||
return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
|
||||
}
|
||||
|
||||
func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
|
||||
val := ctx.Value("syntheticParameters")
|
||||
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
syntheticParameters, ok := val.(map[string]string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return syntheticParameters
|
||||
}
|
||||
|
||||
// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
|
||||
// OnNewFunction is called when a new function is created
|
||||
// OnFunctionFinished is called when a function is finished
|
||||
func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
|
||||
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
|
||||
var res []ToolCallResponse
|
||||
|
||||
for _, call := range toolCalls {
|
||||
|
Reference in New Issue
Block a user