Files
go-llm/v2/tool.go
T
steve 5c5d861915
CI / Root Module (push) Failing after 30s
CI / Lint (push) Failing after 3s
CI / V2 Module (push) Successful in 1m54s
fix(v2): coerce string-encoded numbers/bools in tool arguments
LLMs occasionally return numeric or boolean tool-call fields as JSON
strings (e.g. "3" instead of 3, "true" instead of true), which Go's
strict json.Unmarshal rejects. The strict unmarshal stays as the happy
path; on failure we retry with a coercion pass that walks the target
struct (recursing into nested structs, slices, maps, and pointer fields)
and converts strings to the appropriate kind. Returns the original error
if coercion can't recover.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-27 22:12:56 +00:00

202 lines
5.4 KiB
Go

package llm
import (
"context"
"encoding/json"
"fmt"
"reflect"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
)
// Tool describes one capability the LLM may call, together with the
// handler that implements it.
type Tool struct {
	// Name uniquely identifies the tool.
	Name string
	// Description explains to the LLM what the tool does.
	Description string
	// Schema holds the JSON Schema describing the tool's parameters.
	Schema map[string]any

	// fn wraps the handler function; it is populated by Define or
	// DefineSimple.
	fn reflect.Value
	// pTyp is the handler's parameter type, or nil when the tool takes no
	// parameters beyond the context.
	pTyp reflect.Type

	// isMCP marks tools served by an MCP server; mcpServer is the server
	// that Execute forwards calls to in that case.
	isMCP     bool
	mcpServer *MCPServer
}
// Define creates a tool from a typed handler function.
// T must be a struct. Struct fields become the tool's parameters.
//
// Struct tags:
//   - `json:"name"` — parameter name
//   - `description:"..."` — parameter description
//   - `enum:"a,b,c"` — enum constraint
//
// Pointer fields are optional; non-pointer fields are required.
//
// Example:
//
//	type WeatherParams struct {
//		City string `json:"city" description:"The city to query"`
//		Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"`
//	}
//
//	llm.Define[WeatherParams]("get_weather", "Get weather for a city",
//		func(ctx context.Context, p WeatherParams) (string, error) {
//			return fmt.Sprintf("72F in %s", p.City), nil
//		},
//	)
func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool {
	var zero T
	// Take the parameter type from *T rather than from zero: for an
	// interface T (e.g. any), reflect.TypeOf(zero) returns nil, which
	// Execute would mistake for a parameterless tool and then call fn with
	// too few arguments. For concrete T both forms yield the same type.
	paramType := reflect.TypeOf((*T)(nil)).Elem()
	return Tool{
		Name:        name,
		Description: description,
		Schema:      schema.FromStruct(zero),
		fn:          reflect.ValueOf(fn),
		pTyp:        paramType,
	}
}
// DefineSimple builds a tool whose handler receives only the context and no
// parameters. Its schema advertises an object with no properties.
//
// Example:
//
//	llm.DefineSimple("get_time", "Get the current time",
//		func(ctx context.Context) (string, error) {
//			return time.Now().Format(time.RFC3339), nil
//		},
//	)
func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool {
	emptyObject := map[string]any{
		"type":       "object",
		"properties": map[string]any{},
	}

	// pTyp is deliberately left nil: Execute uses that to recognise a
	// parameterless handler.
	var tool Tool
	tool.Name = name
	tool.Description = description
	tool.Schema = emptyObject
	tool.fn = reflect.ValueOf(fn)
	return tool
}
// Execute runs the tool with the given JSON arguments string and returns the
// handler's text result.
//
// MCP-backed tools decode argsJSON into a generic map and forward it to the
// MCP server. Parameterless tools ignore argsJSON. Typed tools decode
// argsJSON into a new value of the parameter type; an empty string or "{}"
// leaves that value at its zero value.
//
// LLMs sometimes send numeric/boolean fields as JSON strings (e.g. "3"
// instead of 3). When strict decoding fails, Execute retries once with the
// payload rewritten by coerceArgsToType. If that retry cannot recover, the
// original decode error is returned, wrapped as "invalid tool arguments".
// Errors returned by the handler itself are passed through unwrapped.
func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) {
	if t.isMCP {
		var args map[string]any
		if argsJSON != "" && argsJSON != "{}" {
			if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
				return "", fmt.Errorf("invalid MCP tool arguments: %w", err)
			}
		}
		return t.mcpServer.CallTool(ctx, t.Name, args)
	}

	// Parameterless tool: the handler takes only the context.
	if t.pTyp == nil {
		out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx)})
		if !out[1].IsNil() {
			return "", out[1].Interface().(error)
		}
		return out[0].String(), nil
	}

	// Typed tool: decode the JSON into a new *T, then call fn(ctx, T).
	p := reflect.New(t.pTyp)
	if argsJSON != "" && argsJSON != "{}" {
		raw := []byte(argsJSON)
		if err := json.Unmarshal(raw, p.Interface()); err != nil {
			// Decode the coerced payload into a fresh value: a failed
			// Unmarshal can leave p partially populated, and decoding on top
			// of it would merge map fields and reuse pointer targets from
			// the rejected payload.
			fresh := reflect.New(t.pTyp)
			coerced, cerr := coerceArgsToType(raw, t.pTyp)
			if cerr == nil {
				cerr = json.Unmarshal(coerced, fresh.Interface())
			}
			if cerr != nil {
				// Report the original failure: it describes what the model
				// actually sent, whereas the retry's error would refer to
				// our rewritten payload.
				return "", fmt.Errorf("invalid tool arguments: %w", err)
			}
			p = fresh
		}
	}

	out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
	if !out[1].IsNil() {
		return "", out[1].Interface().(error)
	}
	return out[0].String(), nil
}
// ToolBox groups the tools an LLM may call, indexed by tool name, together
// with the MCP servers that were registered with it.
type ToolBox struct {
	tools      map[string]Tool // keyed by Tool.Name
	mcpServers []*MCPServer    // servers registered via AddMCP
}
// NewToolBox returns a ToolBox holding the given tools. When several tools
// share a name, the one listed last wins.
func NewToolBox(tools ...Tool) *ToolBox {
	index := make(map[string]Tool)
	for i := range tools {
		index[tools[i].Name] = tools[i]
	}
	return &ToolBox{tools: index}
}
// Add registers tools in the toolbox, replacing any existing tool with the
// same name, and returns the receiver so calls can be chained. The index is
// allocated on first use, so Add also works on a zero-value ToolBox.
func (tb *ToolBox) Add(tools ...Tool) *ToolBox {
	index := tb.tools
	if index == nil {
		index = make(map[string]Tool)
		tb.tools = index
	}
	for _, tool := range tools {
		index[tool.Name] = tool
	}
	return tb
}
// AddMCP records server in the toolbox and registers every tool reported by
// its ListTools, replacing same-named tools. The server must already be
// connected. It returns the receiver for chaining.
func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox {
	if tb.tools == nil {
		tb.tools = map[string]Tool{}
	}
	tb.mcpServers = append(tb.mcpServers, server)
	for _, mcpTool := range server.ListTools() {
		tb.tools[mcpTool.Name] = mcpTool
	}
	return tb
}
// AllTools returns every registered tool, local and MCP-provided alike, as a
// slice in unspecified (map iteration) order. A nil ToolBox yields nil.
func (tb *ToolBox) AllTools() []Tool {
	if tb == nil {
		return nil
	}
	all := make([]Tool, len(tb.tools))
	i := 0
	for _, tool := range tb.tools {
		all[i] = tool
		i++
	}
	return all
}
// Execute dispatches call to the tool registered under call.Name and returns
// its result. It returns ErrNoToolsConfigured for a nil ToolBox, and an
// error wrapping ErrToolNotFound when no tool has that name.
func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) {
	if tb == nil {
		return "", ErrNoToolsConfigured
	}
	if tool, found := tb.tools[call.Name]; found {
		return tool.Execute(ctx, call.Arguments)
	}
	return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name)
}
// ExecuteAll runs each call in order and returns one tool-result message per
// call, tagged with that call's ID. A failing call does not stop the loop:
// its error text, prefixed with "Error: ", becomes that call's result. The
// returned error is always nil.
func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) {
	var results []Message
	for i := range calls {
		text, err := tb.Execute(ctx, calls[i])
		if err != nil {
			text = "Error: " + err.Error()
		}
		results = append(results, ToolResultMessage(calls[i].ID, text))
	}
	return results, nil
}