Add MCP integration with MCPServer for tool-based interactions
- Introduce `MCPServer` to support connecting to MCP servers via stdio, SSE, or HTTP. - Implement tool fetching, management, and invocation through MCP. - Add `WithMCPServer` method to `ToolBox` for seamless tool integration. - Extend schema package to handle raw JSON schemas for MCP tools. - Update documentation with MCP usage guidelines and examples.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
.claude
|
||||
.idea
|
||||
*.exe
|
||||
*.exe
|
||||
.env
|
||||
41
CLAUDE.md
41
CLAUDE.md
@@ -20,6 +20,7 @@
|
||||
- `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`).
|
||||
- Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`.
|
||||
- Schema definitions for tool calling are in the `schema/` directory.
|
||||
- `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers.
|
||||
- **Imports**: Organize imports into groups: standard library, then third-party libraries.
|
||||
- **Documentation**: Use standard Go doc comments for exported symbols.
|
||||
- **README.md**: The README.md file should always be kept up to date with any significant changes to the project.
|
||||
@@ -45,3 +46,43 @@
|
||||
- `Ctrl+S` - Settings
|
||||
- `Ctrl+N` - New conversation
|
||||
- `Esc` - Exit/Cancel
|
||||
|
||||
## MCP (Model Context Protocol) Support
|
||||
|
||||
The library supports connecting to MCP servers to use their tools. MCP servers can be connected via:
|
||||
- **stdio**: Run a command as a subprocess
|
||||
- **sse**: Connect to an SSE endpoint
|
||||
- **http**: Connect to a streamable HTTP endpoint
|
||||
|
||||
### Usage Example
|
||||
```go
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and connect to an MCP server
|
||||
server := &llm.MCPServer{
|
||||
Name: "my-server",
|
||||
Command: "my-mcp-server",
|
||||
Args: []string{"--some-flag"},
|
||||
}
|
||||
if err := server.Connect(ctx); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
// Add the server to a toolbox
|
||||
toolbox := llm.NewToolBox().WithMCPServer(server)
|
||||
|
||||
// Use the toolbox in requests - MCP tools are automatically available
|
||||
req := llm.Request{
|
||||
Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}},
|
||||
Toolbox: toolbox,
|
||||
}
|
||||
```
|
||||
|
||||
### MCPServer Options
|
||||
- `Name`: Friendly name for logging
|
||||
- `Command`: Command to run (for stdio transport)
|
||||
- `Args`: Command arguments
|
||||
- `Env`: Additional environment variables
|
||||
- `URL`: Endpoint URL (for sse/http transport)
|
||||
- `Transport`: "stdio" (default), "sse", or "http"
|
||||
|
||||
6
go.mod
6
go.mod
@@ -8,7 +8,9 @@ require (
|
||||
github.com/charmbracelet/bubbles v0.21.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/liushuangls/go-anthropic/v2 v2.17.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||
github.com/openai/openai-go v1.12.0
|
||||
golang.org/x/image v0.35.0
|
||||
google.golang.org/genai v1.43.0
|
||||
@@ -30,11 +32,11 @@ require (
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
@@ -48,6 +50,7 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
|
||||
go.opentelemetry.io/otel v1.39.0 // indirect
|
||||
@@ -55,6 +58,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.39.0 // indirect
|
||||
golang.org/x/crypto v0.47.0 // indirect
|
||||
golang.org/x/net v0.49.0 // indirect
|
||||
golang.org/x/oauth2 v0.32.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -35,10 +35,14 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -61,6 +65,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -89,6 +95,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
|
||||
@@ -111,6 +119,8 @@ golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
|
||||
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
|
||||
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -119,6 +129,8 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM=
|
||||
|
||||
238
mcp.go
Normal file
238
mcp.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||
)
|
||||
|
||||
// MCPServer represents a connection to an MCP server.
|
||||
// It manages the lifecycle of the connection and provides access to the server's tools.
|
||||
type MCPServer struct {
|
||||
// Name is a friendly name for this server (used for logging/identification)
|
||||
Name string
|
||||
|
||||
// Command is the command to run the MCP server (for stdio transport)
|
||||
Command string
|
||||
|
||||
// Args are arguments to pass to the command
|
||||
Args []string
|
||||
|
||||
// Env are environment variables to set for the command (in addition to current environment)
|
||||
Env []string
|
||||
|
||||
// URL is the URL for SSE or HTTP transport (alternative to Command)
|
||||
URL string
|
||||
|
||||
// Transport specifies the transport type: "stdio" (default), "sse", or "http"
|
||||
Transport string
|
||||
|
||||
client *mcp.Client
|
||||
session *mcp.ClientSession
|
||||
tools map[string]*mcp.Tool // tool name -> tool definition
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the MCP server.
|
||||
func (m *MCPServer) Connect(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.session != nil {
|
||||
return nil // Already connected
|
||||
}
|
||||
|
||||
m.client = mcp.NewClient(&mcp.Implementation{
|
||||
Name: "go-llm",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
var transport mcp.Transport
|
||||
|
||||
switch m.Transport {
|
||||
case "sse":
|
||||
transport = &mcp.SSEClientTransport{
|
||||
Endpoint: m.URL,
|
||||
}
|
||||
case "http":
|
||||
transport = &mcp.StreamableClientTransport{
|
||||
Endpoint: m.URL,
|
||||
}
|
||||
default: // "stdio" or empty
|
||||
cmd := exec.Command(m.Command, m.Args...)
|
||||
cmd.Env = append(os.Environ(), m.Env...)
|
||||
transport = &mcp.CommandTransport{
|
||||
Command: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := m.client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err)
|
||||
}
|
||||
|
||||
m.session = session
|
||||
|
||||
// Load tools
|
||||
m.tools = make(map[string]*mcp.Tool)
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
m.session.Close()
|
||||
m.session = nil
|
||||
return fmt.Errorf("failed to list tools from %s: %w", m.Name, err)
|
||||
}
|
||||
m.tools[tool.Name] = tool
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection to the MCP server.
|
||||
func (m *MCPServer) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.session.Close()
|
||||
m.session = nil
|
||||
m.tools = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// IsConnected returns true if the server is connected.
|
||||
func (m *MCPServer) IsConnected() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.session != nil
|
||||
}
|
||||
|
||||
// Tools returns the list of tool names available from this server.
|
||||
func (m *MCPServer) Tools() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var names []string
|
||||
for name := range m.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// HasTool returns true if this server provides the named tool.
|
||||
func (m *MCPServer) HasTool(name string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.tools[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// CallTool calls a tool on the MCP server.
|
||||
func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) {
|
||||
m.mu.RLock()
|
||||
session := m.session
|
||||
m.mu.RUnlock()
|
||||
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("not connected to MCP server %s", m.Name)
|
||||
}
|
||||
|
||||
result, err := session.CallTool(ctx, &mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process the result content
|
||||
if len(result.Content) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If there's a single text content, return it as a string
|
||||
if len(result.Content) == 1 {
|
||||
if textContent, ok := result.Content[0].(*mcp.TextContent); ok {
|
||||
return textContent.Text, nil
|
||||
}
|
||||
}
|
||||
|
||||
// For multiple contents or non-text, serialize to string
|
||||
return contentToString(result.Content), nil
|
||||
}
|
||||
|
||||
// toFunction converts an MCP tool to a go-llm Function (for schema purposes only).
|
||||
func (m *MCPServer) toFunction(tool *mcp.Tool) Function {
|
||||
var inputSchema map[string]any
|
||||
if tool.InputSchema != nil {
|
||||
data, err := json.Marshal(tool.InputSchema)
|
||||
if err == nil {
|
||||
_ = json.Unmarshal(data, &inputSchema)
|
||||
}
|
||||
}
|
||||
|
||||
if inputSchema == nil {
|
||||
inputSchema = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
return Function{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: schema.NewRaw(inputSchema),
|
||||
}
|
||||
}
|
||||
|
||||
// contentToString converts MCP content to a string representation.
|
||||
func contentToString(content []mcp.Content) string {
|
||||
var parts []string
|
||||
for _, c := range content {
|
||||
switch tc := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, tc.Text)
|
||||
default:
|
||||
if data, err := json.Marshal(c); err == nil {
|
||||
parts = append(parts, string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
return parts[0]
|
||||
}
|
||||
data, _ := json.Marshal(parts)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// WithMCPServer adds an MCP server to the toolbox.
|
||||
// The server must already be connected. Tools from the server will be available
|
||||
// for use, and tool calls will be routed to the appropriate server.
|
||||
func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox {
|
||||
if t.mcpServers == nil {
|
||||
t.mcpServers = make(map[string]*MCPServer)
|
||||
}
|
||||
|
||||
server.mu.RLock()
|
||||
defer server.mu.RUnlock()
|
||||
|
||||
for name, tool := range server.tools {
|
||||
// Add the function definition (for schema)
|
||||
fn := server.toFunction(tool)
|
||||
t.functions[name] = fn
|
||||
|
||||
// Track which server owns this tool
|
||||
t.mcpServers[name] = server
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
134
schema/raw.go
Normal file
134
schema/raw.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
// Raw represents a raw JSON schema that is passed through directly.
|
||||
// This is used for MCP tools where we receive the schema from the server.
|
||||
type Raw struct {
|
||||
schema map[string]any
|
||||
}
|
||||
|
||||
// NewRaw creates a new Raw schema from a map.
|
||||
func NewRaw(schema map[string]any) Raw {
|
||||
if schema == nil {
|
||||
schema = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
return Raw{schema: schema}
|
||||
}
|
||||
|
||||
// NewRawFromJSON creates a new Raw schema from JSON bytes.
|
||||
func NewRawFromJSON(data []byte) (Raw, error) {
|
||||
var schema map[string]any
|
||||
if err := json.Unmarshal(data, &schema); err != nil {
|
||||
return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err)
|
||||
}
|
||||
return NewRaw(schema), nil
|
||||
}
|
||||
|
||||
func (r Raw) OpenAIParameters() openai.FunctionParameters {
|
||||
return openai.FunctionParameters(r.schema)
|
||||
}
|
||||
|
||||
func (r Raw) GoogleParameters() *genai.Schema {
|
||||
return mapToGenaiSchema(r.schema)
|
||||
}
|
||||
|
||||
func (r Raw) AnthropicInputSchema() map[string]any {
|
||||
return r.schema
|
||||
}
|
||||
|
||||
func (r Raw) Required() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r Raw) Description() string {
|
||||
if desc, ok := r.schema["description"].(string); ok {
|
||||
return desc
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r Raw) FromAny(val any) (reflect.Value, error) {
|
||||
return reflect.ValueOf(val), nil
|
||||
}
|
||||
|
||||
func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||
// No-op for raw schemas
|
||||
}
|
||||
|
||||
// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema
|
||||
func mapToGenaiSchema(m map[string]any) *genai.Schema {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := &genai.Schema{}
|
||||
|
||||
// Type
|
||||
if t, ok := m["type"].(string); ok {
|
||||
switch t {
|
||||
case "string":
|
||||
schema.Type = genai.TypeString
|
||||
case "number":
|
||||
schema.Type = genai.TypeNumber
|
||||
case "integer":
|
||||
schema.Type = genai.TypeInteger
|
||||
case "boolean":
|
||||
schema.Type = genai.TypeBoolean
|
||||
case "array":
|
||||
schema.Type = genai.TypeArray
|
||||
case "object":
|
||||
schema.Type = genai.TypeObject
|
||||
}
|
||||
}
|
||||
|
||||
// Description
|
||||
if desc, ok := m["description"].(string); ok {
|
||||
schema.Description = desc
|
||||
}
|
||||
|
||||
// Enum
|
||||
if enum, ok := m["enum"].([]any); ok {
|
||||
for _, e := range enum {
|
||||
if s, ok := e.(string); ok {
|
||||
schema.Enum = append(schema.Enum, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Properties (for objects)
|
||||
if props, ok := m["properties"].(map[string]any); ok {
|
||||
schema.Properties = make(map[string]*genai.Schema)
|
||||
for k, v := range props {
|
||||
if vm, ok := v.(map[string]any); ok {
|
||||
schema.Properties[k] = mapToGenaiSchema(vm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Required
|
||||
if req, ok := m["required"].([]any); ok {
|
||||
for _, r := range req {
|
||||
if s, ok := r.(string); ok {
|
||||
schema.Required = append(schema.Required, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Items (for arrays)
|
||||
if items, ok := m["items"].(map[string]any); ok {
|
||||
schema.Items = mapToGenaiSchema(items)
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
14
toolbox.go
14
toolbox.go
@@ -2,6 +2,7 @@ package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
// the correct parameters.
|
||||
type ToolBox struct {
|
||||
functions map[string]Function
|
||||
mcpServers map[string]*MCPServer // tool name -> MCP server that provides it
|
||||
dontRequireTool bool
|
||||
}
|
||||
|
||||
@@ -91,6 +93,18 @@ var (
|
||||
)
|
||||
|
||||
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
||||
// Check if this is an MCP tool
|
||||
if server, ok := t.mcpServers[functionName]; ok {
|
||||
var args map[string]any
|
||||
if params != "" {
|
||||
if err := json.Unmarshal([]byte(params), &args); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err)
|
||||
}
|
||||
}
|
||||
return server.CallTool(ctx, functionName, args)
|
||||
}
|
||||
|
||||
// Regular function
|
||||
f, ok := t.functions[functionName]
|
||||
|
||||
if !ok {
|
||||
|
||||
Reference in New Issue
Block a user