- Introduce `MCPServer` to support connecting to MCP servers via stdio, SSE, or HTTP. - Implement tool fetching, management, and invocation through MCP. - Add `WithMCPServer` method to `ToolBox` for seamless tool integration. - Extend schema package to handle raw JSON schemas for MCP tools. - Update documentation with MCP usage guidelines and examples.
239 lines
5.3 KiB
Go
239 lines
5.3 KiB
Go
package llm
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"slices"
	"sync"

	"github.com/modelcontextprotocol/go-sdk/mcp"

	"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)
|
|
|
|
// MCPServer represents a connection to an MCP (Model Context Protocol) server.
// It manages the lifecycle of the connection and provides access to the
// server's tools.
//
// Fill in the exported configuration fields, call Connect, and call Close
// when done. The stdio transport uses Command/Args/Env; the "sse" and "http"
// transports use URL instead.
type MCPServer struct {
	// Name is a friendly name for this server (used for logging/identification).
	Name string

	// Command is the command to run the MCP server (for stdio transport).
	Command string

	// Args are arguments to pass to the command.
	Args []string

	// Env are extra environment variables for the command, in "KEY=value"
	// form; they are appended after the current process environment.
	Env []string

	// URL is the URL for SSE or HTTP transport (alternative to Command).
	URL string

	// Transport specifies the transport type: "stdio" (default), "sse", or "http".
	Transport string

	// Connection state, managed by Connect/Close. mu guards client, session
	// and tools.
	client  *mcp.Client
	session *mcp.ClientSession
	tools   map[string]*mcp.Tool // tool name -> tool definition
	mu      sync.RWMutex
}
|
|
|
|
// Connect establishes a connection to the MCP server.
|
|
func (m *MCPServer) Connect(ctx context.Context) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.session != nil {
|
|
return nil // Already connected
|
|
}
|
|
|
|
m.client = mcp.NewClient(&mcp.Implementation{
|
|
Name: "go-llm",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
|
|
var transport mcp.Transport
|
|
|
|
switch m.Transport {
|
|
case "sse":
|
|
transport = &mcp.SSEClientTransport{
|
|
Endpoint: m.URL,
|
|
}
|
|
case "http":
|
|
transport = &mcp.StreamableClientTransport{
|
|
Endpoint: m.URL,
|
|
}
|
|
default: // "stdio" or empty
|
|
cmd := exec.Command(m.Command, m.Args...)
|
|
cmd.Env = append(os.Environ(), m.Env...)
|
|
transport = &mcp.CommandTransport{
|
|
Command: cmd,
|
|
}
|
|
}
|
|
|
|
session, err := m.client.Connect(ctx, transport, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err)
|
|
}
|
|
|
|
m.session = session
|
|
|
|
// Load tools
|
|
m.tools = make(map[string]*mcp.Tool)
|
|
for tool, err := range session.Tools(ctx, nil) {
|
|
if err != nil {
|
|
m.session.Close()
|
|
m.session = nil
|
|
return fmt.Errorf("failed to list tools from %s: %w", m.Name, err)
|
|
}
|
|
m.tools[tool.Name] = tool
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes the connection to the MCP server.
|
|
func (m *MCPServer) Close() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.session == nil {
|
|
return nil
|
|
}
|
|
|
|
err := m.session.Close()
|
|
m.session = nil
|
|
m.tools = nil
|
|
return err
|
|
}
|
|
|
|
// IsConnected returns true if the server is connected.
|
|
func (m *MCPServer) IsConnected() bool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
return m.session != nil
|
|
}
|
|
|
|
// Tools returns the list of tool names available from this server.
|
|
func (m *MCPServer) Tools() []string {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
var names []string
|
|
for name := range m.tools {
|
|
names = append(names, name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// HasTool returns true if this server provides the named tool.
|
|
func (m *MCPServer) HasTool(name string) bool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
_, ok := m.tools[name]
|
|
return ok
|
|
}
|
|
|
|
// CallTool calls a tool on the MCP server.
|
|
func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) {
|
|
m.mu.RLock()
|
|
session := m.session
|
|
m.mu.RUnlock()
|
|
|
|
if session == nil {
|
|
return nil, fmt.Errorf("not connected to MCP server %s", m.Name)
|
|
}
|
|
|
|
result, err := session.CallTool(ctx, &mcp.CallToolParams{
|
|
Name: name,
|
|
Arguments: arguments,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Process the result content
|
|
if len(result.Content) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// If there's a single text content, return it as a string
|
|
if len(result.Content) == 1 {
|
|
if textContent, ok := result.Content[0].(*mcp.TextContent); ok {
|
|
return textContent.Text, nil
|
|
}
|
|
}
|
|
|
|
// For multiple contents or non-text, serialize to string
|
|
return contentToString(result.Content), nil
|
|
}
|
|
|
|
// toFunction converts an MCP tool to a go-llm Function (for schema purposes only).
|
|
func (m *MCPServer) toFunction(tool *mcp.Tool) Function {
|
|
var inputSchema map[string]any
|
|
if tool.InputSchema != nil {
|
|
data, err := json.Marshal(tool.InputSchema)
|
|
if err == nil {
|
|
_ = json.Unmarshal(data, &inputSchema)
|
|
}
|
|
}
|
|
|
|
if inputSchema == nil {
|
|
inputSchema = map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
return Function{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Parameters: schema.NewRaw(inputSchema),
|
|
}
|
|
}
|
|
|
|
// contentToString converts MCP content to a string representation.
|
|
func contentToString(content []mcp.Content) string {
|
|
var parts []string
|
|
for _, c := range content {
|
|
switch tc := c.(type) {
|
|
case *mcp.TextContent:
|
|
parts = append(parts, tc.Text)
|
|
default:
|
|
if data, err := json.Marshal(c); err == nil {
|
|
parts = append(parts, string(data))
|
|
}
|
|
}
|
|
}
|
|
if len(parts) == 1 {
|
|
return parts[0]
|
|
}
|
|
data, _ := json.Marshal(parts)
|
|
return string(data)
|
|
}
|
|
|
|
// WithMCPServer adds an MCP server to the toolbox.
|
|
// The server must already be connected. Tools from the server will be available
|
|
// for use, and tool calls will be routed to the appropriate server.
|
|
func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox {
|
|
if t.mcpServers == nil {
|
|
t.mcpServers = make(map[string]*MCPServer)
|
|
}
|
|
|
|
server.mu.RLock()
|
|
defer server.mu.RUnlock()
|
|
|
|
for name, tool := range server.tools {
|
|
// Add the function definition (for schema)
|
|
fn := server.toFunction(tool)
|
|
t.functions[name] = fn
|
|
|
|
// Track which server owns this tool
|
|
t.mcpServers[name] = server
|
|
}
|
|
|
|
return t
|
|
}
|