From 5c5d86191579a926017d4c7b4c67cf2739377f27 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Mon, 27 Apr 2026 22:12:56 +0000 Subject: [PATCH] fix(v2): coerce string-encoded numbers/bools in tool arguments LLMs occasionally return numeric or boolean tool-call fields as JSON strings (e.g. "3" instead of 3, "true" instead of true), which Go's strict json.Unmarshal rejects. The strict unmarshal stays as the happy path; on failure we retry with a coercion pass that walks the target struct (recursing into nested structs, slices, maps, and pointer fields) and converts strings to the appropriate kind. Returns the original error if coercion can't recover. Co-Authored-By: Claude Opus 4.7 --- v2/tool.go | 12 +++- v2/tool_coerce.go | 134 +++++++++++++++++++++++++++++++++++++++++ v2/tool_coerce_test.go | 130 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 274 insertions(+), 2 deletions(-) create mode 100644 v2/tool_coerce.go create mode 100644 v2/tool_coerce_test.go diff --git a/v2/tool.go b/v2/tool.go index c90f1c7..072d970 100644 --- a/v2/tool.go +++ b/v2/tool.go @@ -104,8 +104,16 @@ func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) { // Typed tool: unmarshal JSON into the struct, call the function p := reflect.New(t.pTyp) if argsJSON != "" && argsJSON != "{}" { - if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil { - return "", fmt.Errorf("invalid tool arguments: %w", err) + err := json.Unmarshal([]byte(argsJSON), p.Interface()) + if err != nil { + // LLMs sometimes return numeric/boolean fields as JSON strings + // (e.g. "3" instead of 3). Retry with type coercion. + if coerced, cerr := coerceArgsToType([]byte(argsJSON), t.pTyp); cerr == nil { + err = json.Unmarshal(coerced, p.Interface()) + } + if err != nil { + return "", fmt.Errorf("invalid tool arguments: %w", err) + } } } out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) diff --git a/v2/tool_coerce.go b/v2/tool_coerce.go new file mode 100644 index 0000000..802a51b --- /dev/null +++ b/v2/tool_coerce.go @@ -0,0 +1,134 @@ +package llm + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" +) + +// coerceArgsToType reparses argsJSON with leniency: where the target struct +// expects a numeric or boolean field but the JSON value is a string, it +// converts the string to the target kind. Recurses into nested structs, +// slices, maps, and pointer fields. +// +// Returns a freshly marshaled JSON byte slice that can be unmarshaled into +// the target type with strict json.Unmarshal. +func coerceArgsToType(argsJSON []byte, target reflect.Type) ([]byte, error) { + var raw any + if err := json.Unmarshal(argsJSON, &raw); err != nil { + return nil, err + } + raw = coerceValue(raw, target) + return json.Marshal(raw) +} + +func coerceValue(v any, t reflect.Type) any { + if t == nil { + return v + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + m, ok := v.(map[string]any) + if !ok { + return v + } + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + name := jsonFieldName(f) + if name == "-" { + continue + } + if val, present := m[name]; present { + m[name] = coerceValue(val, f.Type) + } + } + return m + + case reflect.Slice, reflect.Array: + arr, ok := v.([]any) + if !ok { + return v + } + elemType := t.Elem() + for i := range arr { + arr[i] = coerceValue(arr[i], elemType) + } + return arr + + case reflect.Map: + m, ok := v.(map[string]any) + if !ok { + return v + } + valType := t.Elem() + for k := range m { + m[k] = coerceValue(m[k], valType) + } + return m + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "+") + if n, err := strconv.ParseInt(s, 10, 64); err == nil { + return n + } + if f, err := strconv.ParseFloat(s, 64); err == nil { + return int64(f) + } + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "+") + if n, err := strconv.ParseUint(s, 10, 64); err == nil { + return n + } + if f, err := strconv.ParseFloat(s, 64); err == nil && f >= 0 { + return uint64(f) + } + } + + case reflect.Float32, reflect.Float64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f + } + } + + case reflect.Bool: + if s, ok := v.(string); ok { + if b, err := strconv.ParseBool(strings.TrimSpace(s)); err == nil { + return b + } + } + } + return v +} + +func jsonFieldName(f reflect.StructField) string { + tag := f.Tag.Get("json") + if tag == "" { + return f.Name + } + if idx := strings.Index(tag, ","); idx >= 0 { + tag = tag[:idx] + } + if tag == "-" { + return "-" + } + if tag == "" { + return f.Name + } + return tag +} diff --git a/v2/tool_coerce_test.go b/v2/tool_coerce_test.go new file mode 100644 index 0000000..b5affc3 --- /dev/null +++ b/v2/tool_coerce_test.go @@ -0,0 +1,130 @@ +package llm + +import ( + "context" + "testing" +) + +func TestExecuteCoercesStringNumbers(t *testing.T) { + type params struct { + Memory string `json:"memory"` + ReplaceMemoryID *uint `json:"replace_memory_id,omitempty"` + RelationshipChange int `json:"relationship_change"` + } + + var got params + tool := Define("process", "test", + func(ctx context.Context, p params) (string, error) { + got = p + return "ok", nil + }, + ) + + cases := []struct { + name string + args string + wantInt int + wantUint uint + }{ + {"int as string", `{"memory":"x","relationship_change":"3"}`, 3, 0}, + {"int as string with plus", `{"memory":"x","relationship_change":"+3"}`, 3, 0}, + {"int as string negative", `{"memory":"x","relationship_change":"-2"}`, -2, 0}, + {"int as string with whitespace", `{"memory":"x","relationship_change":" 4 "}`, 4, 0}, + {"int as string with decimal", `{"memory":"x","relationship_change":"2.7"}`, 2, 0}, + {"native int still works", `{"memory":"x","relationship_change":5}`, 5, 0}, + {"pointer uint as string", `{"memory":"x","replace_memory_id":"42","relationship_change":0}`, 0, 42}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got = params{} + result, err := tool.Execute(context.Background(), tc.args) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if result != "ok" { + t.Errorf("expected 'ok', got %q", result) + } + if got.RelationshipChange != tc.wantInt { + t.Errorf("RelationshipChange: want %d, got %d", tc.wantInt, got.RelationshipChange) + } + if tc.wantUint != 0 { + if got.ReplaceMemoryID == nil || *got.ReplaceMemoryID != tc.wantUint { + t.Errorf("ReplaceMemoryID: want %d, got %v", tc.wantUint, got.ReplaceMemoryID) + } + } + }) + } +} + +func TestExecuteCoercesStringBoolAndFloat(t *testing.T) { + type params struct { + Enabled bool `json:"enabled"` + Ratio float64 `json:"ratio"` + } + + var got params + tool := Define("cfg", "test", + func(ctx context.Context, p params) (string, error) { + got = p + return "ok", nil + }, + ) + + if _, err := tool.Execute(context.Background(), `{"enabled":"true","ratio":"0.5"}`); err != nil { + t.Fatalf("execute failed: %v", err) + } + if !got.Enabled { + t.Errorf("expected enabled=true, got false") + } + if got.Ratio != 0.5 { + t.Errorf("expected ratio=0.5, got %v", got.Ratio) + } +} + +func TestExecuteCoercesNestedAndSlices(t *testing.T) { + type inner struct { + N int `json:"n"` + } + type params struct { + Items []inner `json:"items"` + Tags []int `json:"tags"` + } + + var got params + tool := Define("nest", "test", + func(ctx context.Context, p params) (string, error) { + got = p + return "ok", nil + }, + ) + + args := `{"items":[{"n":"1"},{"n":"2"}],"tags":["10","20"]}` + if _, err := tool.Execute(context.Background(), args); err != nil { + t.Fatalf("execute failed: %v", err) + } + if len(got.Items) != 2 || got.Items[0].N != 1 || got.Items[1].N != 2 { + t.Errorf("nested struct coercion failed: %+v", got.Items) + } + if len(got.Tags) != 2 || got.Tags[0] != 10 || got.Tags[1] != 20 { + t.Errorf("slice element coercion failed: %+v", got.Tags) + } +} + +func TestExecuteUnrecoverableArgsErrors(t *testing.T) { + type params struct { + N int `json:"n"` + } + tool := Define("bad", "test", + func(ctx context.Context, p params) (string, error) { + return "ok", nil + }, + ) + + if _, err := tool.Execute(context.Background(), `{"n":"not-a-number"}`); err == nil { + t.Errorf("expected error for unparseable string") + } + if _, err := tool.Execute(context.Background(), `{not json`); err == nil { + t.Errorf("expected error for malformed JSON") + } +}