diff --git a/.gitignore b/.gitignore index 0a94273..e6a0eb9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .claude .idea -*.exe \ No newline at end of file +*.exe +.env \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 1185e07..3824b30 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,6 +20,7 @@ - `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`). - Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`. - Schema definitions for tool calling are in the `schema/` directory. + - `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers. - **Imports**: Organize imports into groups: standard library, then third-party libraries. - **Documentation**: Use standard Go doc comments for exported symbols. - **README.md**: The README.md file should always be kept up to date with any significant changes to the project. @@ -45,3 +46,43 @@ - `Ctrl+S` - Settings - `Ctrl+N` - New conversation - `Esc` - Exit/Cancel + +## MCP (Model Context Protocol) Support + +The library supports connecting to MCP servers to use their tools. MCP servers can be connected via: +- **stdio**: Run a command as a subprocess +- **sse**: Connect to an SSE endpoint +- **http**: Connect to a streamable HTTP endpoint + +### Usage Example +```go +ctx := context.Background() + +// Create and connect to an MCP server +server := &llm.MCPServer{ + Name: "my-server", + Command: "my-mcp-server", + Args: []string{"--some-flag"}, +} +if err := server.Connect(ctx); err != nil { + log.Fatal(err) +} +defer server.Close() + +// Add the server to a toolbox +toolbox := llm.NewToolBox().WithMCPServer(server) + +// Use the toolbox in requests - MCP tools are automatically available +req := llm.Request{ + Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}}, + Toolbox: toolbox, +} +``` + +### MCPServer Options +- `Name`: Friendly name for logging +- `Command`: Command to run (for stdio transport) +- `Args`: Command arguments +- `Env`: Additional environment variables +- `URL`: Endpoint URL (for sse/http transport) +- `Transport`: "stdio" (default), "sse", or "http" diff --git a/go.mod b/go.mod index c9bec2a..8e10002 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,9 @@ require ( github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 + github.com/joho/godotenv v1.5.1 github.com/liushuangls/go-anthropic/v2 v2.17.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/openai/openai-go v1.12.0 golang.org/x/image v0.35.0 google.golang.org/genai v1.43.0 @@ -30,11 +32,11 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect github.com/googleapis/gax-go/v2 v2.16.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect - github.com/joho/godotenv v1.5.1 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -48,6 +50,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect go.opentelemetry.io/otel v1.39.0 // indirect @@ -55,6 +58,7 @@ require ( go.opentelemetry.io/otel/trace v1.39.0 // indirect golang.org/x/crypto v0.47.0 // indirect golang.org/x/net v0.49.0 // indirect + golang.org/x/oauth2 v0.32.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect diff --git a/go.sum b/go.sum index c250d8c..a7a8e88 100644 --- a/go.sum +++ b/go.sum @@ -35,10 +35,14 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -61,6 +65,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= @@ -89,6 +95,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= @@ -111,6 +119,8 @@ golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -119,6 +129,8 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM= diff --git a/mcp.go b/mcp.go new file mode 100644 index 0000000..342a6fb --- /dev/null +++ b/mcp.go @@ -0,0 +1,238 @@ +package llm + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "sync" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "gitea.stevedudenhoeffer.com/steve/go-llm/schema" +) + +// MCPServer represents a connection to an MCP server. +// It manages the lifecycle of the connection and provides access to the server's tools. +type MCPServer struct { + // Name is a friendly name for this server (used for logging/identification) + Name string + + // Command is the command to run the MCP server (for stdio transport) + Command string + + // Args are arguments to pass to the command + Args []string + + // Env are environment variables to set for the command (in addition to current environment) + Env []string + + // URL is the URL for SSE or HTTP transport (alternative to Command) + URL string + + // Transport specifies the transport type: "stdio" (default), "sse", or "http" + Transport string + + client *mcp.Client + session *mcp.ClientSession + tools map[string]*mcp.Tool // tool name -> tool definition + mu sync.RWMutex +} + +// Connect establishes a connection to the MCP server. +func (m *MCPServer) Connect(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.session != nil { + return nil // Already connected + } + + m.client = mcp.NewClient(&mcp.Implementation{ + Name: "go-llm", + Version: "1.0.0", + }, nil) + + var transport mcp.Transport + + switch m.Transport { + case "sse": + transport = &mcp.SSEClientTransport{ + Endpoint: m.URL, + } + case "http": + transport = &mcp.StreamableClientTransport{ + Endpoint: m.URL, + } + default: // "stdio" or empty + cmd := exec.Command(m.Command, m.Args...) + cmd.Env = append(os.Environ(), m.Env...) + transport = &mcp.CommandTransport{ + Command: cmd, + } + } + + session, err := m.client.Connect(ctx, transport, nil) + if err != nil { + return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err) + } + + m.session = session + + // Load tools + m.tools = make(map[string]*mcp.Tool) + for tool, err := range session.Tools(ctx, nil) { + if err != nil { + m.session.Close() + m.session = nil + return fmt.Errorf("failed to list tools from %s: %w", m.Name, err) + } + m.tools[tool.Name] = tool + } + + return nil +} + +// Close closes the connection to the MCP server. +func (m *MCPServer) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.session == nil { + return nil + } + + err := m.session.Close() + m.session = nil + m.tools = nil + return err +} + +// IsConnected returns true if the server is connected. +func (m *MCPServer) IsConnected() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.session != nil +} + +// Tools returns the list of tool names available from this server. +func (m *MCPServer) Tools() []string { + m.mu.RLock() + defer m.mu.RUnlock() + + var names []string + for name := range m.tools { + names = append(names, name) + } + return names +} + +// HasTool returns true if this server provides the named tool. +func (m *MCPServer) HasTool(name string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, ok := m.tools[name] + return ok +} + +// CallTool calls a tool on the MCP server. +func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) { + m.mu.RLock() + session := m.session + m.mu.RUnlock() + + if session == nil { + return nil, fmt.Errorf("not connected to MCP server %s", m.Name) + } + + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: name, + Arguments: arguments, + }) + if err != nil { + return nil, err + } + + // Process the result content + if len(result.Content) == 0 { + return nil, nil + } + + // If there's a single text content, return it as a string + if len(result.Content) == 1 { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { + return textContent.Text, nil + } + } + + // For multiple contents or non-text, serialize to string + return contentToString(result.Content), nil +} + +// toFunction converts an MCP tool to a go-llm Function (for schema purposes only). +func (m *MCPServer) toFunction(tool *mcp.Tool) Function { + var inputSchema map[string]any + if tool.InputSchema != nil { + data, err := json.Marshal(tool.InputSchema) + if err == nil { + _ = json.Unmarshal(data, &inputSchema) + } + } + + if inputSchema == nil { + inputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + + return Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: schema.NewRaw(inputSchema), + } +} + +// contentToString converts MCP content to a string representation. +func contentToString(content []mcp.Content) string { + var parts []string + for _, c := range content { + switch tc := c.(type) { + case *mcp.TextContent: + parts = append(parts, tc.Text) + default: + if data, err := json.Marshal(c); err == nil { + parts = append(parts, string(data)) + } + } + } + if len(parts) == 1 { + return parts[0] + } + data, _ := json.Marshal(parts) + return string(data) +} + +// WithMCPServer adds an MCP server to the toolbox. +// The server must already be connected. Tools from the server will be available +// for use, and tool calls will be routed to the appropriate server. +func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox { + if t.mcpServers == nil { + t.mcpServers = make(map[string]*MCPServer) + } + + server.mu.RLock() + defer server.mu.RUnlock() + + for name, tool := range server.tools { + // Add the function definition (for schema) + fn := server.toFunction(tool) + t.functions[name] = fn + + // Track which server owns this tool + t.mcpServers[name] = server + } + + return t +} diff --git a/schema/raw.go b/schema/raw.go new file mode 100644 index 0000000..b4dcb27 --- /dev/null +++ b/schema/raw.go @@ -0,0 +1,134 @@ +package schema + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/openai/openai-go" + "google.golang.org/genai" +) + +// Raw represents a raw JSON schema that is passed through directly. +// This is used for MCP tools where we receive the schema from the server. +type Raw struct { + schema map[string]any +} + +// NewRaw creates a new Raw schema from a map. +func NewRaw(schema map[string]any) Raw { + if schema == nil { + schema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + return Raw{schema: schema} +} + +// NewRawFromJSON creates a new Raw schema from JSON bytes. +func NewRawFromJSON(data []byte) (Raw, error) { + var schema map[string]any + if err := json.Unmarshal(data, &schema); err != nil { + return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err) + } + return NewRaw(schema), nil +} + +func (r Raw) OpenAIParameters() openai.FunctionParameters { + return openai.FunctionParameters(r.schema) +} + +func (r Raw) GoogleParameters() *genai.Schema { + return mapToGenaiSchema(r.schema) +} + +func (r Raw) AnthropicInputSchema() map[string]any { + return r.schema +} + +func (r Raw) Required() bool { + return false +} + +func (r Raw) Description() string { + if desc, ok := r.schema["description"].(string); ok { + return desc + } + return "" +} + +func (r Raw) FromAny(val any) (reflect.Value, error) { + return reflect.ValueOf(val), nil +} + +func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) { + // No-op for raw schemas +} + +// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema +func mapToGenaiSchema(m map[string]any) *genai.Schema { + if m == nil { + return nil + } + + schema := &genai.Schema{} + + // Type + if t, ok := m["type"].(string); ok { + switch t { + case "string": + schema.Type = genai.TypeString + case "number": + schema.Type = genai.TypeNumber + case "integer": + schema.Type = genai.TypeInteger + case "boolean": + schema.Type = genai.TypeBoolean + case "array": + schema.Type = genai.TypeArray + case "object": + schema.Type = genai.TypeObject + } + } + + // Description + if desc, ok := m["description"].(string); ok { + schema.Description = desc + } + + // Enum + if enum, ok := m["enum"].([]any); ok { + for _, e := range enum { + if s, ok := e.(string); ok { + schema.Enum = append(schema.Enum, s) + } + } + } + + // Properties (for objects) + if props, ok := m["properties"].(map[string]any); ok { + schema.Properties = make(map[string]*genai.Schema) + for k, v := range props { + if vm, ok := v.(map[string]any); ok { + schema.Properties[k] = mapToGenaiSchema(vm) + } + } + } + + // Required + if req, ok := m["required"].([]any); ok { + for _, r := range req { + if s, ok := r.(string); ok { + schema.Required = append(schema.Required, s) + } + } + } + + // Items (for arrays) + if items, ok := m["items"].(map[string]any); ok { + schema.Items = mapToGenaiSchema(items) + } + + return schema +} diff --git a/toolbox.go b/toolbox.go index 313d9bd..10cb892 100644 --- a/toolbox.go +++ b/toolbox.go @@ -2,6 +2,7 @@ package llm import ( "context" + "encoding/json" "errors" "fmt" ) @@ -11,6 +12,7 @@ import ( // the correct parameters. type ToolBox struct { functions map[string]Function + mcpServers map[string]*MCPServer // tool name -> MCP server that provides it dontRequireTool bool } @@ -91,6 +93,18 @@ var ( ) func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { + // Check if this is an MCP tool + if server, ok := t.mcpServers[functionName]; ok { + var args map[string]any + if params != "" { + if err := json.Unmarshal([]byte(params), &args); err != nil { + return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err) + } + } + return server.CallTool(ctx, functionName, args) + } + + // Regular function f, ok := t.functions[functionName] if !ok {