v2 is a new Go module (v2/) with a dramatically simpler API: - Unified Message type (no more Input marker interface) - Define[T] for ergonomic tool creation with standard context.Context - Chat session with automatic tool-call loop (agent loop) - Streaming via pull-based StreamReader - MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer) - Middleware support (logging, retry, timeout, usage tracking) - Decoupled JSON Schema (map[string]any, no provider coupling) - Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP - Providers: OpenAI, Anthropic, Google (all with streaming) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
194 lines
5.1 KiB
Go
194 lines
5.1 KiB
Go
package llm
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"reflect"
	"sort"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
)
|
|
|
|
// Tool defines a tool that the LLM can invoke.
//
// Tools are created with Define (typed parameters) or DefineSimple (no
// parameters), or are obtained from an MCP server via its ListTools. The
// unexported fields are populated by those constructors and read by Execute.
type Tool struct {
	// Name is the tool's unique identifier.
	Name string

	// Description tells the LLM what this tool does.
	Description string

	// Schema is the JSON Schema for the tool's parameters.
	Schema map[string]any

	// fn holds the implementation function (set via Define or DefineSimple).
	// Execute invokes it through reflection with either (ctx) or (ctx, params).
	fn reflect.Value
	// pTyp is the handler's parameter type; Execute decodes the JSON
	// arguments into a fresh value of this type on every call.
	pTyp reflect.Type // nil for parameterless tools

	// isMCP indicates this tool is provided by an MCP server.
	isMCP bool
	// mcpServer is the server Execute forwards the call to when isMCP is set.
	mcpServer *MCPServer
}
|
|
|
|
// Define creates a tool from a typed handler function.
|
|
// T must be a struct. Struct fields become the tool's parameters.
|
|
//
|
|
// Struct tags:
|
|
// - `json:"name"` — parameter name
|
|
// - `description:"..."` — parameter description
|
|
// - `enum:"a,b,c"` — enum constraint
|
|
//
|
|
// Pointer fields are optional; non-pointer fields are required.
|
|
//
|
|
// Example:
|
|
//
|
|
// type WeatherParams struct {
|
|
// City string `json:"city" description:"The city to query"`
|
|
// Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"`
|
|
// }
|
|
//
|
|
// llm.Define[WeatherParams]("get_weather", "Get weather for a city",
|
|
// func(ctx context.Context, p WeatherParams) (string, error) {
|
|
// return fmt.Sprintf("72F in %s", p.City), nil
|
|
// },
|
|
// )
|
|
func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool {
|
|
var zero T
|
|
return Tool{
|
|
Name: name,
|
|
Description: description,
|
|
Schema: schema.FromStruct(zero),
|
|
fn: reflect.ValueOf(fn),
|
|
pTyp: reflect.TypeOf(zero),
|
|
}
|
|
}
|
|
|
|
// DefineSimple creates a parameterless tool.
|
|
//
|
|
// Example:
|
|
//
|
|
// llm.DefineSimple("get_time", "Get the current time",
|
|
// func(ctx context.Context) (string, error) {
|
|
// return time.Now().Format(time.RFC3339), nil
|
|
// },
|
|
// )
|
|
func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool {
|
|
return Tool{
|
|
Name: name,
|
|
Description: description,
|
|
Schema: map[string]any{"type": "object", "properties": map[string]any{}},
|
|
fn: reflect.ValueOf(fn),
|
|
}
|
|
}
|
|
|
|
// Execute runs the tool with the given JSON arguments string.
|
|
func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) {
|
|
if t.isMCP {
|
|
var args map[string]any
|
|
if argsJSON != "" && argsJSON != "{}" {
|
|
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
|
return "", fmt.Errorf("invalid MCP tool arguments: %w", err)
|
|
}
|
|
}
|
|
return t.mcpServer.CallTool(ctx, t.Name, args)
|
|
}
|
|
|
|
// Parameterless tool
|
|
if t.pTyp == nil {
|
|
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx)})
|
|
if !out[1].IsNil() {
|
|
return "", out[1].Interface().(error)
|
|
}
|
|
return out[0].String(), nil
|
|
}
|
|
|
|
// Typed tool: unmarshal JSON into the struct, call the function
|
|
p := reflect.New(t.pTyp)
|
|
if argsJSON != "" && argsJSON != "{}" {
|
|
if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil {
|
|
return "", fmt.Errorf("invalid tool arguments: %w", err)
|
|
}
|
|
}
|
|
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
|
if !out[1].IsNil() {
|
|
return "", out[1].Interface().(error)
|
|
}
|
|
return out[0].String(), nil
|
|
}
|
|
|
|
// ToolBox is a collection of tools available for use by an LLM.
//
// Tools are keyed by Name; adding a tool whose name is already present
// replaces the earlier one. A nil *ToolBox is accepted by AllTools and
// Execute, but Add and AddMCP require a non-nil receiver.
type ToolBox struct {
	// tools maps each tool's Name to the tool. It is nil in a zero ToolBox;
	// Add and AddMCP allocate it on first use.
	tools map[string]Tool
	// mcpServers records every server registered through AddMCP.
	mcpServers []*MCPServer
}
|
|
|
|
// NewToolBox creates a new ToolBox from the given tools.
|
|
func NewToolBox(tools ...Tool) *ToolBox {
|
|
tb := &ToolBox{tools: make(map[string]Tool)}
|
|
for _, t := range tools {
|
|
tb.tools[t.Name] = t
|
|
}
|
|
return tb
|
|
}
|
|
|
|
// Add adds tools to the toolbox and returns it for chaining.
|
|
func (tb *ToolBox) Add(tools ...Tool) *ToolBox {
|
|
if tb.tools == nil {
|
|
tb.tools = make(map[string]Tool)
|
|
}
|
|
for _, t := range tools {
|
|
tb.tools[t.Name] = t
|
|
}
|
|
return tb
|
|
}
|
|
|
|
// AddMCP adds an MCP server's tools to the toolbox. The server must be connected.
|
|
func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox {
|
|
if tb.tools == nil {
|
|
tb.tools = make(map[string]Tool)
|
|
}
|
|
tb.mcpServers = append(tb.mcpServers, server)
|
|
|
|
for _, tool := range server.ListTools() {
|
|
tb.tools[tool.Name] = tool
|
|
}
|
|
return tb
|
|
}
|
|
|
|
// AllTools returns all tools (local + MCP) as a slice.
|
|
func (tb *ToolBox) AllTools() []Tool {
|
|
if tb == nil {
|
|
return nil
|
|
}
|
|
tools := make([]Tool, 0, len(tb.tools))
|
|
for _, t := range tb.tools {
|
|
tools = append(tools, t)
|
|
}
|
|
return tools
|
|
}
|
|
|
|
// Execute executes a tool call by name.
|
|
func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) {
|
|
if tb == nil {
|
|
return "", ErrNoToolsConfigured
|
|
}
|
|
tool, ok := tb.tools[call.Name]
|
|
if !ok {
|
|
return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name)
|
|
}
|
|
return tool.Execute(ctx, call.Arguments)
|
|
}
|
|
|
|
// ExecuteAll executes all tool calls and returns tool result messages.
|
|
func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) {
|
|
var results []Message
|
|
for _, call := range calls {
|
|
result, err := tb.Execute(ctx, call)
|
|
text := result
|
|
if err != nil {
|
|
text = "Error: " + err.Error()
|
|
}
|
|
results = append(results, ToolResultMessage(call.ID, text))
|
|
}
|
|
return results, nil
|
|
}
|