diff --git a/openai.go b/openai.go index 812f97e..decbf55 100644 --- a/openai.go +++ b/openai.go @@ -3,7 +3,6 @@ package go_llm import ( "context" "fmt" - "log/slog" "strings" oai "github.com/sashabaranov/go-openai" @@ -40,8 +39,10 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletio Parameters: tool.Parameters.Definition(), }, }) + } - fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition()) + if !request.Toolbox.dontRequireTool { + res.ToolChoice = "required" } } @@ -72,7 +73,6 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R for _, choice := range response.Choices { var toolCalls []ToolCall for _, call := range choice.Message.ToolCalls { - fmt.Println("responseToLLMResponse: call:", call.Function.Arguments) toolCall := ToolCall{ ID: call.ID, FunctionCall: FunctionCall{ @@ -81,8 +81,6 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R }, } - fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments) - toolCalls = append(toolCalls, toolCall) } @@ -103,11 +101,8 @@ func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response req := o.newRequestToOpenAIRequest(request) - slog.Info("openaiImpl.ChatComplete", "req", fmt.Sprintf("%#v", req)) resp, err := cl.CreateChatCompletion(ctx, req) - fmt.Println("resp:", fmt.Sprintf("%#v", resp)) - if err != nil { return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err) } diff --git a/toolbox.go b/toolbox.go index 208219f..e4af6e7 100644 --- a/toolbox.go +++ b/toolbox.go @@ -1,8 +1,10 @@ package go_llm import ( + "context" "errors" "fmt" + "github.com/sashabaranov/go-openai" ) @@ -10,8 +12,9 @@ import ( // It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with // the correct parameters. type ToolBox struct { - funcs []Function - names map[string]Function + funcs []Function + names map[string]Function + dontRequireTool bool } func NewToolBox(fns ...*Function) *ToolBox { @@ -37,6 +40,27 @@ func (t *ToolBox) WithFunction(f Function) *ToolBox { return &t2 } +func (t *ToolBox) WithFunctionRemoved(name string) *ToolBox { + t2 := *t + + delete(t2.names, name) + + for i, f := range t2.funcs { + if f.Name == name { + t2.funcs = append(t2.funcs[:i], t2.funcs[i+1:]...) + break + } + } + + return &t2 +} + +func (t *ToolBox) WithRequireTool(val bool) *ToolBox { + t2 := *t + t2.dontRequireTool = !val + return &t2 +} + // ToOpenAI will convert the current ToolBox to a slice of openai.Tool, which can be used to send to the OpenAI API. func (t *ToolBox) toOpenAI() []openai.Tool { var res []openai.Tool @@ -76,3 +100,43 @@ func (t *ToolBox) executeFunction(ctx *Context, functionName string, params stri func (t *ToolBox) Execute(ctx *Context, toolCall ToolCall) (string, error) { return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) } + +// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. +// OnNewFunction is called when a new function is created +// OnFunctionFinished is called when a function is finished +func (t *ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result string, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { + var res []ToolCallResponse + + for _, call := range toolCalls { + ctx := ctx.WithToolCall(&call) + if call.FunctionCall.Name == "" { + return nil, newError(ErrFunctionNotFound, errors.New("function name is empty")) + } + + var arg any + if OnNewFunction != nil { + var err error + arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments) + + if err != nil { + return nil, newError(ErrFunctionNotFound, err) + } + } + out, err := t.Execute(ctx, call) + + if OnFunctionFinished != nil { + err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg) + if err != nil { + return nil, newError(ErrFunctionNotFound, err) + } + } + + res = append(res, ToolCallResponse{ + ID: call.ID, + Result: out, + Error: err, + }) + } + + return res, nil +}