Fix unmarshalling issues and adjust logging for debugging

Modify `FunctionCall` struct to handle arguments as strings. Add debugging logs to facilitate error tracing and improve JSON unmarshalling in various functions.
This commit is contained in:
Steve Dudenhoeffer 2024-11-11 00:23:01 -05:00
parent cd4ad59a38
commit 0993a8e865
5 changed files with 37 additions and 28 deletions

View File

@ -2,6 +2,7 @@ package go_llm
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
anth "github.com/liushuangls/go-anthropic/v2" anth "github.com/liushuangls/go-anthropic/v2"
"log" "log"
@ -120,13 +121,18 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
case anth.MessagesContentTypeToolUse: case anth.MessagesContentTypeToolUse:
if msg.MessageContentToolUse != nil { if msg.MessageContentToolUse != nil {
choice.Calls = append(choice.Calls, ToolCall{ b, e := json.Marshal(msg.MessageContentToolUse.Input)
ID: msg.MessageContentToolUse.ID, if e != nil {
FunctionCall: FunctionCall{ log.Println("failed to marshal input", e)
Name: msg.MessageContentToolUse.Name, } else {
Arguments: msg.MessageContentToolUse.Input, choice.Calls = append(choice.Calls, ToolCall{
}, ID: msg.MessageContentToolUse.ID,
}) FunctionCall: FunctionCall{
Name: msg.MessageContentToolUse.Name,
Arguments: string(b),
},
})
}
} }
} }

View File

@ -37,16 +37,17 @@ func (f *Function) Execute(ctx context.Context, input string) (string, error) {
} }
// first, we need to parse the input into the struct // first, we need to parse the input into the struct
p := reflect.New(f.paramType).Elem() p := reflect.New(f.paramType)
fmt.Println("Function.Execute", f.Name, "input:", input)
//m := map[string]any{} //m := map[string]any{}
err := json.Unmarshal([]byte(input), p.Addr().Interface()) err := json.Unmarshal([]byte(input), p.Interface())
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unmarshal input: %w", err) return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
} }
// now we can call the function // now we can call the function
exec := func(ctx context.Context) (string, error) { exec := func(ctx context.Context) (string, error) {
out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p}) out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if len(out) != 2 { if len(out) != 2 {
return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out)) return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out))
@ -87,5 +88,5 @@ func (f *Function) toOpenAIDefinition() jsonschema.Definition {
type FunctionCall struct { type FunctionCall struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Arguments any `json:"arguments,omitempty"` Arguments string `json:"arguments,omitempty"`
} }

View File

@ -2,6 +2,7 @@ package go_llm
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/google/generative-ai-go/genai" "github.com/google/generative-ai-go/genai"
"google.golang.org/api/option" "google.golang.org/api/option"
@ -63,11 +64,17 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
choice := ResponseChoice{} choice := ResponseChoice{}
choice.Content = v.Name choice.Content = v.Name
b, e := json.Marshal(v.Args)
if e != nil {
return Response{}, fmt.Errorf("error marshalling args: %w", e)
}
call := ToolCall{ call := ToolCall{
ID: v.Name, ID: v.Name,
FunctionCall: FunctionCall{ FunctionCall: FunctionCall{
Name: v.Name, Name: v.Name,
Arguments: v.Args, Arguments: string(b),
}, },
} }

View File

@ -64,9 +64,11 @@ func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRe
Name: tool.Name, Name: tool.Name,
Description: tool.Description, Description: tool.Description,
Strict: tool.Strict, Strict: tool.Strict,
Parameters: tool.Parameters, Parameters: tool.Parameters.Definition(),
}, },
}) })
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
} }
if request.Temperature != nil { if request.Temperature != nil {
@ -94,8 +96,9 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
res := Response{} res := Response{}
for _, choice := range response.Choices { for _, choice := range response.Choices {
var tools []ToolCall var toolCalls []ToolCall
for _, call := range choice.Message.ToolCalls { for _, call := range choice.Message.ToolCalls {
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
toolCall := ToolCall{ toolCall := ToolCall{
ID: call.ID, ID: call.ID,
FunctionCall: FunctionCall{ FunctionCall: FunctionCall{
@ -104,7 +107,9 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
}, },
} }
tools = append(tools, toolCall) fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
toolCalls = append(toolCalls, toolCall)
} }
res.Choices = append(res.Choices, ResponseChoice{ res.Choices = append(res.Choices, ResponseChoice{
@ -112,6 +117,7 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
Role: Role(choice.Message.Role), Role: Role(choice.Message.Role),
Name: choice.Message.Name, Name: choice.Message.Name,
Refusal: choice.Message.Refusal, Refusal: choice.Message.Refusal,
Calls: toolCalls,
}) })
} }

View File

@ -2,11 +2,9 @@ package go_llm
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"log/slog"
) )
// ToolBox is a collection of tools that OpenAI can use to execute functions. // ToolBox is a collection of tools that OpenAI can use to execute functions.
@ -69,8 +67,6 @@ var (
func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) { func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, params string) (string, error) {
f, ok := t.names[functionName] f, ok := t.names[functionName]
slog.Info("functionName", functionName)
if !ok { if !ok {
return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
} }
@ -79,12 +75,5 @@ func (t *ToolBox) ExecuteFunction(ctx context.Context, functionName string, para
} }
func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) { func (t *ToolBox) Execute(ctx context.Context, toolCall ToolCall) (string, error) {
slog.Info("toolCall", toolCall) return t.ExecuteFunction(ctx, toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
b, err := json.Marshal(toolCall.FunctionCall.Arguments)
if err != nil {
return "", fmt.Errorf("failed to marshal arguments: %w", err)
}
return t.ExecuteFunction(ctx, toolCall.ID, string(b))
} }