fix(v2): coerce string-encoded numbers/bools in tool arguments
LLMs occasionally return numeric or boolean tool-call fields as JSON strings (e.g. "3" instead of 3, "true" instead of true), which Go's strict json.Unmarshal rejects. The strict unmarshal stays as the happy path; on failure we retry with a coercion pass that walks the target struct (recursing into nested structs, slices, maps, and pointer fields) and converts strings to the appropriate kind. Returns the original error if coercion can't recover. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+9
-1
@@ -104,10 +104,18 @@ func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) {
|
|||||||
// Typed tool: unmarshal JSON into the struct, call the function
|
// Typed tool: unmarshal JSON into the struct, call the function
|
||||||
p := reflect.New(t.pTyp)
|
p := reflect.New(t.pTyp)
|
||||||
if argsJSON != "" && argsJSON != "{}" {
|
if argsJSON != "" && argsJSON != "{}" {
|
||||||
if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil {
|
err := json.Unmarshal([]byte(argsJSON), p.Interface())
|
||||||
|
if err != nil {
|
||||||
|
// LLMs sometimes return numeric/boolean fields as JSON strings
|
||||||
|
// (e.g. "3" instead of 3). Retry with type coercion.
|
||||||
|
if coerced, cerr := coerceArgsToType([]byte(argsJSON), t.pTyp); cerr == nil {
|
||||||
|
err = json.Unmarshal(coerced, p.Interface())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid tool arguments: %w", err)
|
return "", fmt.Errorf("invalid tool arguments: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
||||||
if !out[1].IsNil() {
|
if !out[1].IsNil() {
|
||||||
return "", out[1].Interface().(error)
|
return "", out[1].Interface().(error)
|
||||||
|
|||||||
@@ -0,0 +1,134 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// coerceArgsToType reparses argsJSON with leniency: where the target struct
|
||||||
|
// expects a numeric or boolean field but the JSON value is a string, it
|
||||||
|
// converts the string to the target kind. Recurses into nested structs,
|
||||||
|
// slices, maps, and pointer fields.
|
||||||
|
//
|
||||||
|
// Returns a freshly marshaled JSON byte slice that can be unmarshaled into
|
||||||
|
// the target type with strict json.Unmarshal.
|
||||||
|
func coerceArgsToType(argsJSON []byte, target reflect.Type) ([]byte, error) {
|
||||||
|
var raw any
|
||||||
|
if err := json.Unmarshal(argsJSON, &raw); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
raw = coerceValue(raw, target)
|
||||||
|
return json.Marshal(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func coerceValue(v any, t reflect.Type) any {
|
||||||
|
if t == nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
for t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
m, ok := v.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
f := t.Field(i)
|
||||||
|
if !f.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := jsonFieldName(f)
|
||||||
|
if name == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if val, present := m[name]; present {
|
||||||
|
m[name] = coerceValue(val, f.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
arr, ok := v.([]any)
|
||||||
|
if !ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
elemType := t.Elem()
|
||||||
|
for i := range arr {
|
||||||
|
arr[i] = coerceValue(arr[i], elemType)
|
||||||
|
}
|
||||||
|
return arr
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
m, ok := v.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
valType := t.Elem()
|
||||||
|
for k := range m {
|
||||||
|
m[k] = coerceValue(m[k], valType)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
s = strings.TrimPrefix(s, "+")
|
||||||
|
if n, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||||
|
return int64(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
s = strings.TrimPrefix(s, "+")
|
||||||
|
if n, err := strconv.ParseUint(s, 10, 64); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
if f, err := strconv.ParseFloat(s, 64); err == nil && f >= 0 {
|
||||||
|
return uint64(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
if b, err := strconv.ParseBool(strings.TrimSpace(s)); err == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func jsonFieldName(f reflect.StructField) string {
|
||||||
|
tag := f.Tag.Get("json")
|
||||||
|
if tag == "" {
|
||||||
|
return f.Name
|
||||||
|
}
|
||||||
|
if idx := strings.Index(tag, ","); idx >= 0 {
|
||||||
|
tag = tag[:idx]
|
||||||
|
}
|
||||||
|
if tag == "-" {
|
||||||
|
return "-"
|
||||||
|
}
|
||||||
|
if tag == "" {
|
||||||
|
return f.Name
|
||||||
|
}
|
||||||
|
return tag
|
||||||
|
}
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExecuteCoercesStringNumbers(t *testing.T) {
|
||||||
|
type params struct {
|
||||||
|
Memory string `json:"memory"`
|
||||||
|
ReplaceMemoryID *uint `json:"replace_memory_id,omitempty"`
|
||||||
|
RelationshipChange int `json:"relationship_change"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var got params
|
||||||
|
tool := Define("process", "test",
|
||||||
|
func(ctx context.Context, p params) (string, error) {
|
||||||
|
got = p
|
||||||
|
return "ok", nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
args string
|
||||||
|
wantInt int
|
||||||
|
wantUint uint
|
||||||
|
}{
|
||||||
|
{"int as string", `{"memory":"x","relationship_change":"3"}`, 3, 0},
|
||||||
|
{"int as string with plus", `{"memory":"x","relationship_change":"+3"}`, 3, 0},
|
||||||
|
{"int as string negative", `{"memory":"x","relationship_change":"-2"}`, -2, 0},
|
||||||
|
{"int as string with whitespace", `{"memory":"x","relationship_change":" 4 "}`, 4, 0},
|
||||||
|
{"int as string with decimal", `{"memory":"x","relationship_change":"2.7"}`, 2, 0},
|
||||||
|
{"native int still works", `{"memory":"x","relationship_change":5}`, 5, 0},
|
||||||
|
{"pointer uint as string", `{"memory":"x","replace_memory_id":"42","relationship_change":0}`, 0, 42},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got = params{}
|
||||||
|
result, err := tool.Execute(context.Background(), tc.args)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if result != "ok" {
|
||||||
|
t.Errorf("expected 'ok', got %q", result)
|
||||||
|
}
|
||||||
|
if got.RelationshipChange != tc.wantInt {
|
||||||
|
t.Errorf("RelationshipChange: want %d, got %d", tc.wantInt, got.RelationshipChange)
|
||||||
|
}
|
||||||
|
if tc.wantUint != 0 {
|
||||||
|
if got.ReplaceMemoryID == nil || *got.ReplaceMemoryID != tc.wantUint {
|
||||||
|
t.Errorf("ReplaceMemoryID: want %d, got %v", tc.wantUint, got.ReplaceMemoryID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteCoercesStringBoolAndFloat(t *testing.T) {
|
||||||
|
type params struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Ratio float64 `json:"ratio"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var got params
|
||||||
|
tool := Define("cfg", "test",
|
||||||
|
func(ctx context.Context, p params) (string, error) {
|
||||||
|
got = p
|
||||||
|
return "ok", nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := tool.Execute(context.Background(), `{"enabled":"true","ratio":"0.5"}`); err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if !got.Enabled {
|
||||||
|
t.Errorf("expected enabled=true, got false")
|
||||||
|
}
|
||||||
|
if got.Ratio != 0.5 {
|
||||||
|
t.Errorf("expected ratio=0.5, got %v", got.Ratio)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteCoercesNestedAndSlices(t *testing.T) {
|
||||||
|
type inner struct {
|
||||||
|
N int `json:"n"`
|
||||||
|
}
|
||||||
|
type params struct {
|
||||||
|
Items []inner `json:"items"`
|
||||||
|
Tags []int `json:"tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var got params
|
||||||
|
tool := Define("nest", "test",
|
||||||
|
func(ctx context.Context, p params) (string, error) {
|
||||||
|
got = p
|
||||||
|
return "ok", nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
args := `{"items":[{"n":"1"},{"n":"2"}],"tags":["10","20"]}`
|
||||||
|
if _, err := tool.Execute(context.Background(), args); err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(got.Items) != 2 || got.Items[0].N != 1 || got.Items[1].N != 2 {
|
||||||
|
t.Errorf("nested struct coercion failed: %+v", got.Items)
|
||||||
|
}
|
||||||
|
if len(got.Tags) != 2 || got.Tags[0] != 10 || got.Tags[1] != 20 {
|
||||||
|
t.Errorf("slice element coercion failed: %+v", got.Tags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteUnrecoverableArgsErrors(t *testing.T) {
|
||||||
|
type params struct {
|
||||||
|
N int `json:"n"`
|
||||||
|
}
|
||||||
|
tool := Define("bad", "test",
|
||||||
|
func(ctx context.Context, p params) (string, error) {
|
||||||
|
return "ok", nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := tool.Execute(context.Background(), `{"n":"not-a-number"}`); err == nil {
|
||||||
|
t.Errorf("expected error for unparseable string")
|
||||||
|
}
|
||||||
|
if _, err := tool.Execute(context.Background(), `{not json`); err == nil {
|
||||||
|
t.Errorf("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user