Files
go-llm/v2/tool_coerce.go
T
steve 5c5d861915
CI / Root Module (push) Failing after 30s
CI / Lint (push) Failing after 3s
CI / V2 Module (push) Successful in 1m54s
fix(v2): coerce string-encoded numbers/bools in tool arguments
LLMs occasionally return numeric or boolean tool-call fields as JSON
strings (e.g. "3" instead of 3, "true" instead of true), which Go's
strict json.Unmarshal rejects. The strict unmarshal stays as the happy
path; on failure we retry with a coercion pass that walks the target
struct (recursing into nested structs, slices, maps, and pointer fields)
and converts strings to the appropriate kind. Returns the original error
if coercion can't recover.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-27 22:12:56 +00:00

135 lines
2.8 KiB
Go

package llm
import (
"encoding/json"
"reflect"
"strconv"
"strings"
)
// coerceArgsToType reparses argsJSON with leniency: wherever target expects a
// numeric or boolean value but the JSON holds a string, the string is
// converted to that kind (see coerceValue). Nested structs (including embedded
// ones), slices, arrays, maps and pointer fields are followed.
//
// Numbers are decoded as json.Number rather than float64, so values that need
// no coercion are re-emitted exactly as received; a float64 round trip would
// round integers beyond 2^53 (64-bit IDs, for example) before the final strict
// unmarshal ever saw them.
//
// It returns a freshly marshaled JSON byte slice meant to be unmarshaled into
// target with strict json.Unmarshal. For input that is not valid JSON it
// returns the same error json.Unmarshal reports.
func coerceArgsToType(argsJSON []byte, target reflect.Type) ([]byte, error) {
	if !json.Valid(argsJSON) {
		var discard any
		return nil, json.Unmarshal(argsJSON, &discard)
	}
	dec := json.NewDecoder(strings.NewReader(string(argsJSON)))
	dec.UseNumber()
	var raw any
	// json.Valid guaranteed exactly one well-formed value, so unlike a bare
	// Decoder this cannot silently ignore trailing data.
	if err := dec.Decode(&raw); err != nil {
		return nil, err
	}
	return json.Marshal(coerceValue(raw, target))
}

// coerceValue walks v, a value decoded from JSON into an any, alongside the Go
// type t it will finally be unmarshaled into, and returns v with scalar
// mismatches repaired:
//
//   - strings become int64, uint64, float64 or bool where t expects that kind
//     and the text parses cleanly;
//   - integer targets also accept integral float spellings ("3.0", "1e3"),
//     quoted or as a bare json.Number, but never fractional or out-of-range
//     values, which are left for the strict unmarshal to reject.
//
// Objects and arrays in v are updated in place. Anything that cannot be
// converted is returned unchanged, so the subsequent strict unmarshal still
// reports a type error for it. Types with their own UnmarshalJSON or
// UnmarshalText method are left untouched, as are fields tagged with the
// ",string" option, which require the quoted form.
func coerceValue(v any, t reflect.Type) any {
	if t == nil {
		return v
	}
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if hasCustomJSONDecoding(t) {
		return v
	}
	switch t.Kind() {
	case reflect.Struct:
		obj, ok := v.(map[string]any)
		if !ok {
			return v
		}
		fields := collectCoercionFields(t)
		for key, val := range obj {
			if ft, ok := lookupCoercionField(fields, key); ok {
				obj[key] = coerceValue(val, ft)
			}
		}
		return obj
	case reflect.Slice, reflect.Array:
		arr, ok := v.([]any)
		if !ok {
			return v
		}
		for i, elem := range arr {
			arr[i] = coerceValue(elem, t.Elem())
		}
		return arr
	case reflect.Map:
		obj, ok := v.(map[string]any)
		if !ok {
			return v
		}
		for key, val := range obj {
			obj[key] = coerceValue(val, t.Elem())
		}
		return obj
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if s, ok := coercionText(v); ok {
			if n, ok := parseLenientInt(s); ok {
				return n
			}
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		if s, ok := coercionText(v); ok {
			if n, ok := parseLenientUint(s); ok {
				return n
			}
		}
	case reflect.Float32, reflect.Float64:
		if s, ok := v.(string); ok {
			f, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
			// f-f is 0 only for finite f (NaN for ±Inf and NaN); json.Marshal
			// cannot encode non-finite floats, so leave "inf"/"nan" as text.
			if err == nil && f-f == 0 {
				return f
			}
		}
	case reflect.Bool:
		if s, ok := v.(string); ok {
			if b, err := strconv.ParseBool(strings.TrimSpace(s)); err == nil {
				return b
			}
		}
	}
	return v
}

// coercionField is one JSON object key that encoding/json decodes into a
// struct field, paired with that field's Go type.
type coercionField struct {
	name   string
	typ    reflect.Type // nil means "leave the value exactly as received"
	depth  int          // embedding depth; shallower fields shadow deeper ones
	tagged bool         // name came from an explicit json tag
}

// collectCoercionFields lists the fields encoding/json would decode into for
// struct type t. Fields of untagged embedded structs are promoted
// breadth-first, following encoding/json's shadowing rules: a shallower field
// wins, and at equal depth a tagged field beats an untagged one.
func collectCoercionFields(t reflect.Type) []coercionField {
	var fields []coercionField
	byName := make(map[string]int) // field name -> index in fields
	visited := make(map[reflect.Type]bool)
	level := []reflect.Type{t}
	for depth := 0; len(level) > 0; depth++ {
		var next []reflect.Type
		for _, st := range level {
			if visited[st] {
				continue // guards against embedding cycles through pointers
			}
			visited[st] = true
			for i := 0; i < st.NumField(); i++ {
				sf := st.Field(i)
				ft := sf.Type
				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
					ft = ft.Elem()
				}
				if sf.Anonymous {
					// encoding/json skips embedded unexported non-struct
					// types, but an embedded unexported struct can still
					// contribute exported fields.
					if !sf.IsExported() && ft.Kind() != reflect.Struct {
						continue
					}
				} else if !sf.IsExported() {
					continue
				}
				tag := sf.Tag.Get("json")
				if tag == "-" {
					continue // `json:"-,"`, by contrast, names the key "-"
				}
				tagName, opts, _ := strings.Cut(tag, ",")
				if sf.Anonymous && tagName == "" && ft.Kind() == reflect.Struct {
					next = append(next, ft)
					continue
				}
				f := coercionField{
					name:   jsonFieldName(sf),
					typ:    sf.Type,
					depth:  depth,
					tagged: tagName != "",
				}
				if hasQuotedOption(opts, ft.Kind()) {
					f.typ = nil
				}
				if at, dup := byName[f.name]; dup {
					// Breadth-first order means fields[at] is never deeper.
					if fields[at].depth == depth && f.tagged && !fields[at].tagged {
						fields[at] = f
					}
					continue
				}
				byName[f.name] = len(fields)
				fields = append(fields, f)
			}
		}
		level = next
	}
	return fields
}

// hasQuotedOption reports whether tag options opts contain "string" and the
// field kind k is one encoding/json honours it for; such fields must stay
// quoted, so converting them would break decoding.
func hasQuotedOption(opts string, k reflect.Kind) bool {
	switch k {
	case reflect.Bool, reflect.String,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
	default:
		return false
	}
	for _, opt := range strings.Split(opts, ",") {
		if opt == "string" {
			return true
		}
	}
	return false
}

// lookupCoercionField finds the field encoding/json would decode key into:
// an exact name match first, otherwise the first case-insensitive match.
func lookupCoercionField(fields []coercionField, key string) (reflect.Type, bool) {
	for _, f := range fields {
		if f.name == key {
			return f.typ, true
		}
	}
	for _, f := range fields {
		if strings.EqualFold(f.name, key) {
			return f.typ, true
		}
	}
	return nil, false
}

// hasCustomJSONDecoding reports whether encoding/json would hand values of
// type t to an UnmarshalJSON or UnmarshalText method instead of decoding them
// by kind. Such types define their own accepted input and are not coerced.
func hasCustomJSONDecoding(t reflect.Type) bool {
	pt := reflect.PointerTo(t)
	jsonType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
	// Declared structurally (same method set as encoding.TextUnmarshaler) to
	// avoid importing package encoding.
	textType := reflect.TypeOf((*interface{ UnmarshalText([]byte) error })(nil)).Elem()
	return pt.Implements(jsonType) || pt.Implements(textType)
}

// coercionText returns the text of v when it is a string or a json.Number,
// the form numbers take under Decoder.UseNumber.
func coercionText(v any) (string, bool) {
	switch x := v.(type) {
	case string:
		return x, true
	case json.Number:
		return string(x), true
	}
	return "", false
}

// parseLenientInt parses s as a base-10 int64, also accepting integral float
// spellings such as "3.0" or "1e3". Fractions and values outside int64 are
// rejected: truncating "3.7" would silently change the value, and Go leaves
// out-of-range float-to-int conversions implementation-defined.
func parseLenientInt(s string) (int64, bool) {
	s = strings.TrimSpace(s)
	if n, err := strconv.ParseInt(s, 10, 64); err == nil {
		return n, true
	}
	f, err := strconv.ParseFloat(s, 64)
	// A negated range test, so NaN (every comparison false) is rejected too.
	if err != nil || !(f >= -(1<<63) && f < 1<<63) {
		return 0, false
	}
	n := int64(f)
	if float64(n) != f {
		return 0, false
	}
	return n, true
}

// parseLenientUint is the unsigned counterpart of parseLenientInt. A leading
// '+' is stripped because strconv.ParseUint, unlike ParseInt, rejects signs.
func parseLenientUint(s string) (uint64, bool) {
	s = strings.TrimPrefix(strings.TrimSpace(s), "+")
	if n, err := strconv.ParseUint(s, 10, 64); err == nil {
		return n, true
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil || !(f >= 0 && f < 1<<64) {
		return 0, false
	}
	n := uint64(f)
	if float64(n) != f {
		return 0, false
	}
	return n, true
}

// jsonFieldName returns the JSON object key encoding/json uses for f: the name
// part of its json tag, or the Go field name when the tag is absent or names
// nothing (e.g. `json:",omitempty"`). A result of "-" is ambiguous: it is
// returned both for `json:"-"` (field omitted) and `json:"-,"` (key literally
// "-"), so callers that must tell them apart should inspect the raw tag.
func jsonFieldName(f reflect.StructField) string {
	name, _, _ := strings.Cut(f.Tag.Get("json"), ",")
	if name == "" {
		return f.Name
	}
	return name
}