diff --git a/v2/tool.go b/v2/tool.go index c90f1c7..072d970 100644 --- a/v2/tool.go +++ b/v2/tool.go @@ -104,8 +104,16 @@ func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) { // Typed tool: unmarshal JSON into the struct, call the function p := reflect.New(t.pTyp) if argsJSON != "" && argsJSON != "{}" { - if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil { - return "", fmt.Errorf("invalid tool arguments: %w", err) + err := json.Unmarshal([]byte(argsJSON), p.Interface()) + if err != nil { + // LLMs sometimes return numeric/boolean fields as JSON strings + // (e.g. "3" instead of 3). Retry with type coercion. + if coerced, cerr := coerceArgsToType([]byte(argsJSON), t.pTyp); cerr == nil { + err = json.Unmarshal(coerced, p.Interface()) + } + if err != nil { + return "", fmt.Errorf("invalid tool arguments: %w", err) + } } } out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) diff --git a/v2/tool_coerce.go b/v2/tool_coerce.go new file mode 100644 index 0000000..802a51b --- /dev/null +++ b/v2/tool_coerce.go @@ -0,0 +1,134 @@ +package llm + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" +) + +// coerceArgsToType reparses argsJSON with leniency: where the target struct +// expects a numeric or boolean field but the JSON value is a string, it +// converts the string to the target kind. Recurses into nested structs, +// slices, maps, and pointer fields. +// +// Returns a freshly marshaled JSON byte slice that can be unmarshaled into +// the target type with strict json.Unmarshal. +func coerceArgsToType(argsJSON []byte, target reflect.Type) ([]byte, error) { + var raw any + if err := json.Unmarshal(argsJSON, &raw); err != nil { + return nil, err + } + raw = coerceValue(raw, target) + return json.Marshal(raw) +} + +func coerceValue(v any, t reflect.Type) any { + if t == nil { + return v + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + m, ok := v.(map[string]any) + if !ok { + return v + } + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + name := jsonFieldName(f) + if name == "-" { + continue + } + if val, present := m[name]; present { + m[name] = coerceValue(val, f.Type) + } + } + return m + + case reflect.Slice, reflect.Array: + arr, ok := v.([]any) + if !ok { + return v + } + elemType := t.Elem() + for i := range arr { + arr[i] = coerceValue(arr[i], elemType) + } + return arr + + case reflect.Map: + m, ok := v.(map[string]any) + if !ok { + return v + } + valType := t.Elem() + for k := range m { + m[k] = coerceValue(m[k], valType) + } + return m + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "+") + if n, err := strconv.ParseInt(s, 10, 64); err == nil { + return n + } + if f, err := strconv.ParseFloat(s, 64); err == nil { + return int64(f) + } + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "+") + if n, err := strconv.ParseUint(s, 10, 64); err == nil { + return n + } + if f, err := strconv.ParseFloat(s, 64); err == nil && f >= 0 { + return uint64(f) + } + } + + case reflect.Float32, reflect.Float64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f + } + } + + case reflect.Bool: + if s, ok := v.(string); ok { + if b, err := strconv.ParseBool(strings.TrimSpace(s)); err == nil { + return b + } + } + } + return v +} + +func jsonFieldName(f reflect.StructField) string { + tag := f.Tag.Get("json") + if tag == "" { + return f.Name + } + if idx := strings.Index(tag, ","); idx >= 0 { + tag = tag[:idx] + } + if tag == "-" { + return "-" + } + if tag == "" { + return f.Name + } + return tag +} diff --git a/v2/tool_coerce_test.go b/v2/tool_coerce_test.go new file mode 100644 index 0000000..b5affc3 --- /dev/null +++ b/v2/tool_coerce_test.go @@ -0,0 +1,130 @@ +package llm + +import ( + "context" + "testing" +) + +func TestExecuteCoercesStringNumbers(t *testing.T) { + type params struct { + Memory string `json:"memory"` + ReplaceMemoryID *uint `json:"replace_memory_id,omitempty"` + RelationshipChange int `json:"relationship_change"` + } + + var got params + tool := Define("process", "test", + func(ctx context.Context, p params) (string, error) { + got = p + return "ok", nil + }, + ) + + cases := []struct { + name string + args string + wantInt int + wantUint uint + }{ + {"int as string", `{"memory":"x","relationship_change":"3"}`, 3, 0}, + {"int as string with plus", `{"memory":"x","relationship_change":"+3"}`, 3, 0}, + {"int as string negative", `{"memory":"x","relationship_change":"-2"}`, -2, 0}, + {"int as string with whitespace", `{"memory":"x","relationship_change":" 4 "}`, 4, 0}, + {"int as string with decimal", `{"memory":"x","relationship_change":"2.7"}`, 2, 0}, + {"native int still works", `{"memory":"x","relationship_change":5}`, 5, 0}, + {"pointer uint as string", `{"memory":"x","replace_memory_id":"42","relationship_change":0}`, 0, 42}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got = params{} + result, err := tool.Execute(context.Background(), tc.args) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if result != "ok" { + t.Errorf("expected 'ok', got %q", result) + } + if got.RelationshipChange != tc.wantInt { + t.Errorf("RelationshipChange: want %d, got %d", tc.wantInt, got.RelationshipChange) + } + if tc.wantUint != 0 { + if got.ReplaceMemoryID == nil || *got.ReplaceMemoryID != tc.wantUint { + t.Errorf("ReplaceMemoryID: want %d, got %v", tc.wantUint, got.ReplaceMemoryID) + } + } + }) + } +} + +func TestExecuteCoercesStringBoolAndFloat(t *testing.T) { + type params struct { + Enabled bool `json:"enabled"` + Ratio float64 `json:"ratio"` + } + + var got params + tool := Define("cfg", "test", + func(ctx context.Context, p params) (string, error) { + got = p + return "ok", nil + }, + ) + + if _, err := tool.Execute(context.Background(), `{"enabled":"true","ratio":"0.5"}`); err != nil { + t.Fatalf("execute failed: %v", err) + } + if !got.Enabled { + t.Errorf("expected enabled=true, got false") + } + if got.Ratio != 0.5 { + t.Errorf("expected ratio=0.5, got %v", got.Ratio) + } +} + +func TestExecuteCoercesNestedAndSlices(t *testing.T) { + type inner struct { + N int `json:"n"` + } + type params struct { + Items []inner `json:"items"` + Tags []int `json:"tags"` + } + + var got params + tool := Define("nest", "test", + func(ctx context.Context, p params) (string, error) { + got = p + return "ok", nil + }, + ) + + args := `{"items":[{"n":"1"},{"n":"2"}],"tags":["10","20"]}` + if _, err := tool.Execute(context.Background(), args); err != nil { + t.Fatalf("execute failed: %v", err) + } + if len(got.Items) != 2 || got.Items[0].N != 1 || got.Items[1].N != 2 { + t.Errorf("nested struct coercion failed: %+v", got.Items) + } + if len(got.Tags) != 2 || got.Tags[0] != 10 || got.Tags[1] != 20 { + t.Errorf("slice element coercion failed: %+v", got.Tags) + } +} + +func TestExecuteUnrecoverableArgsErrors(t *testing.T) { + type params struct { + N int `json:"n"` + } + tool := Define("bad", "test", + func(ctx context.Context, p params) (string, error) { + return "ok", nil + }, + ) + + if _, err := tool.Execute(context.Background(), `{"n":"not-a-number"}`); err == nil { + t.Errorf("expected error for unparseable string") + } + if _, err := tool.Execute(context.Background(), `{not json`); err == nil { + t.Errorf("expected error for malformed JSON") + } +}