package llm import ( "context" "encoding/json" "fmt" "reflect" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema" ) // Tool defines a tool that the LLM can invoke. type Tool struct { // Name is the tool's unique identifier. Name string // Description tells the LLM what this tool does. Description string // Schema is the JSON Schema for the tool's parameters. Schema map[string]any // fn holds the implementation function (set via Define or DefineSimple). fn reflect.Value pTyp reflect.Type // nil for parameterless tools // isMCP indicates this tool is provided by an MCP server. isMCP bool mcpServer *MCPServer } // Define creates a tool from a typed handler function. // T must be a struct. Struct fields become the tool's parameters. // // Struct tags: // - `json:"name"` — parameter name // - `description:"..."` — parameter description // - `enum:"a,b,c"` — enum constraint // // Pointer fields are optional; non-pointer fields are required. // // Example: // // type WeatherParams struct { // City string `json:"city" description:"The city to query"` // Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"` // } // // llm.Define[WeatherParams]("get_weather", "Get weather for a city", // func(ctx context.Context, p WeatherParams) (string, error) { // return fmt.Sprintf("72F in %s", p.City), nil // }, // ) func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool { var zero T return Tool{ Name: name, Description: description, Schema: schema.FromStruct(zero), fn: reflect.ValueOf(fn), pTyp: reflect.TypeOf(zero), } } // DefineSimple creates a parameterless tool. // // Example: // // llm.DefineSimple("get_time", "Get the current time", // func(ctx context.Context) (string, error) { // return time.Now().Format(time.RFC3339), nil // }, // ) func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool { return Tool{ Name: name, Description: description, Schema: map[string]any{"type": "object", "properties": map[string]any{}}, fn: reflect.ValueOf(fn), } } // Execute runs the tool with the given JSON arguments string. func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) { if t.isMCP { var args map[string]any if argsJSON != "" && argsJSON != "{}" { if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { return "", fmt.Errorf("invalid MCP tool arguments: %w", err) } } return t.mcpServer.CallTool(ctx, t.Name, args) } // Parameterless tool if t.pTyp == nil { out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx)}) if !out[1].IsNil() { return "", out[1].Interface().(error) } return out[0].String(), nil } // Typed tool: unmarshal JSON into the struct, call the function p := reflect.New(t.pTyp) if argsJSON != "" && argsJSON != "{}" { if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil { return "", fmt.Errorf("invalid tool arguments: %w", err) } } out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) if !out[1].IsNil() { return "", out[1].Interface().(error) } return out[0].String(), nil } // ToolBox is a collection of tools available for use by an LLM. type ToolBox struct { tools map[string]Tool mcpServers []*MCPServer } // NewToolBox creates a new ToolBox from the given tools. func NewToolBox(tools ...Tool) *ToolBox { tb := &ToolBox{tools: make(map[string]Tool)} for _, t := range tools { tb.tools[t.Name] = t } return tb } // Add adds tools to the toolbox and returns it for chaining. func (tb *ToolBox) Add(tools ...Tool) *ToolBox { if tb.tools == nil { tb.tools = make(map[string]Tool) } for _, t := range tools { tb.tools[t.Name] = t } return tb } // AddMCP adds an MCP server's tools to the toolbox. The server must be connected. func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox { if tb.tools == nil { tb.tools = make(map[string]Tool) } tb.mcpServers = append(tb.mcpServers, server) for _, tool := range server.ListTools() { tb.tools[tool.Name] = tool } return tb } // AllTools returns all tools (local + MCP) as a slice. func (tb *ToolBox) AllTools() []Tool { if tb == nil { return nil } tools := make([]Tool, 0, len(tb.tools)) for _, t := range tb.tools { tools = append(tools, t) } return tools } // Execute executes a tool call by name. func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) { if tb == nil { return "", ErrNoToolsConfigured } tool, ok := tb.tools[call.Name] if !ok { return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name) } return tool.Execute(ctx, call.Arguments) } // ExecuteAll executes all tool calls and returns tool result messages. func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) { var results []Message for _, call := range calls { result, err := tb.Execute(ctx, call) text := result if err != nil { text = "Error: " + err.Error() } results = append(results, ToolResultMessage(call.ID, text)) } return results, nil }