Add go-llm v2: redesigned API for simpler LLM abstraction

v2 is a new Go module (v2/) with a dramatically simpler API:
- Unified Message type (no more Input marker interface)
- Define[T] for ergonomic tool creation with standard context.Context
- Chat session with automatic tool-call loop (agent loop)
- Streaming via pull-based StreamReader
- MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer)
- Middleware support (logging, retry, timeout, usage tracking)
- Decoupled JSON Schema (map[string]any, no provider coupling)
- Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP
- Providers: OpenAI, Anthropic, Google (all with streaming)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-07 20:00:08 -05:00
parent 85a848d96e
commit a4cb4baab5
28 changed files with 3598 additions and 0 deletions

31
v2/CLAUDE.md Normal file
View File

@@ -0,0 +1,31 @@
# CLAUDE.md for go-llm v2
## Build and Test Commands
- Build project: `cd v2 && go build ./...`
- Run all tests: `cd v2 && go test ./...`
- Run specific test: `cd v2 && go test -v -run <TestName> ./...`
- Tidy dependencies: `cd v2 && go mod tidy`
- Vet: `cd v2 && go vet ./...`
## Code Style Guidelines
- **Indentation**: Standard Go tabs
- **Naming**: `camelCase` for unexported, `PascalCase` for exported
- **Error Handling**: Always check and handle errors immediately. Wrap with context first and the cause last: `fmt.Errorf("doing X: %w", err)`
- **Imports**: Standard library first, then third-party, then internal packages
## Package Structure
- Root package `llm` — public API (Client, Model, Chat, ToolBox, Message types)
- `provider/` — Provider interface that backends implement
- `openai/`, `anthropic/`, `google/` — Provider implementations
- `tools/` — Ready-to-use sample tools (WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP)
- `internal/schema/` — JSON Schema generation from Go structs
- `internal/imageutil/` — Image compression utilities
## Key Design Decisions
1. Unified `Message` type instead of marker interfaces
2. `map[string]any` JSON Schema (no provider coupling)
3. Tool functions return `(string, error)`, use standard `context.Context`
4. `Chat.Send()` auto-loops tool calls; `Chat.SendRaw()` for manual control
5. MCP one-call connect: `MCPStdioServer(ctx, cmd, args...)`
6. Streaming via pull-based `StreamReader.Next()`
7. Middleware for logging, retry, timeout, usage tracking

273
v2/anthropic/anthropic.go Normal file
View File

@@ -0,0 +1,273 @@
// Package anthropic implements the go-llm v2 provider interface for Anthropic.
package anthropic
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/imageutil"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
anth "github.com/liushuangls/go-anthropic/v2"
)
// Provider implements the provider.Provider interface for Anthropic.
//
// It stores only the API key; Complete and Stream each build their own SDK
// client from that key when they are called.
type Provider struct {
	apiKey string // key passed to anth.NewClient for every request
}
// New returns an Anthropic provider that authenticates with apiKey.
// The key is stored as given; nothing is validated and no request is made here.
func New(apiKey string) *Provider {
	p := new(Provider)
	p.apiKey = apiKey
	return p
}
// Complete sends req to Anthropic as one non-streaming Messages call and
// returns the converted result. A failure reported by the SDK is wrapped with
// the prefix "anthropic completion error".
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	client := anth.NewClient(p.apiKey)

	raw, err := client.CreateMessages(ctx, p.buildRequest(req))
	if err != nil {
		var none provider.Response
		return none, fmt.Errorf("anthropic completion error: %w", err)
	}

	converted := p.convertResponse(raw)
	return converted, nil
}
// Stream performs a streaming completion.
//
// Every text delta reported by the SDK is forwarded on events as a
// provider.StreamEventText. Once the stream has finished, one
// provider.StreamEventDone carrying the fully converted response is sent and
// nil is returned. Only text deltas are forwarded incrementally; tool calls
// are visible solely through the final response.
//
// Each send on events also watches ctx, so a consumer that stops reading
// after cancelling ctx cannot block this call forever. The events channel is
// not closed by this method.
//
// It returns the SDK error wrapped with "anthropic stream error", or ctx.Err()
// when ctx is done before the final event could be delivered.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	cl := anth.NewClient(p.apiKey)
	anthReq := p.buildRequest(req)

	// send delivers ev unless ctx is cancelled first.
	send := func(ev provider.StreamEvent) error {
		select {
		case events <- ev:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	resp, err := cl.CreateMessagesStream(ctx, anth.MessagesStreamRequest{
		MessagesRequest: anthReq,
		OnContentBlockDelta: func(data anth.MessagesEventContentBlockDeltaData) {
			if data.Delta.Type != "text_delta" || data.Delta.Text == nil {
				return
			}
			// The callback has no error return, so a delta that cannot be
			// delivered because ctx is done is simply dropped.
			_ = send(provider.StreamEvent{
				Type: provider.StreamEventText,
				Text: *data.Delta.Text,
			})
		},
	})
	if err != nil {
		return fmt.Errorf("anthropic stream error: %w", err)
	}

	result := p.convertResponse(resp)
	return send(provider.StreamEvent{
		Type:     provider.StreamEventDone,
		Response: &result,
	})
}
// buildRequest translates a provider.Request into an Anthropic MessagesRequest.
//
// Translation rules:
//   - MaxTokens defaults to 4096 unless req.MaxTokens is set.
//   - "system" messages are joined with newlines into the System field.
//   - "tool" messages become tool_result blocks inside a user-role message
//     (IsError is always false).
//   - "assistant" messages keep their text and tool calls; every other role is
//     sent as user. A message that carries images is always sent as user.
//   - Images are sent as base64. Inline images of 5 MiB or more are compressed
//     when possible; URL images are downloaded first. An image that cannot be
//     downloaded, or whose download does not return a 2xx status, is skipped.
//   - Consecutive messages that end up with the same role are merged into one,
//     including consecutive tool results.
//   - Messages that end up with no content blocks are dropped.
//
// Image downloads use http.Get, so they are not bound to any request context.
func (p *Provider) buildRequest(req provider.Request) anth.MessagesRequest {
	anthReq := anth.MessagesRequest{
		Model:     anth.Model(req.Model),
		MaxTokens: 4096,
	}
	if req.MaxTokens != nil {
		anthReq.MaxTokens = *req.MaxTokens
	}

	var msgs []anth.Message

	// appendMsg adds m to msgs, folding it into the previous message when both
	// share a role (the conversation is kept strictly alternating) and
	// discarding messages that carry no content blocks at all.
	appendMsg := func(m anth.Message) {
		if len(m.Content) == 0 {
			return
		}
		if n := len(msgs); n > 0 && msgs[n-1].Role == m.Role {
			msgs[n-1].Content = append(msgs[n-1].Content, m.Content...)
			return
		}
		msgs = append(msgs, m)
	}

	for _, msg := range req.Messages {
		if msg.Role == "system" {
			if len(anthReq.System) > 0 {
				anthReq.System += "\n"
			}
			anthReq.System += msg.Content
			continue
		}

		if msg.Role == "tool" {
			// Tool results travel as tool_result blocks in a user message.
			// Going through appendMsg groups the results of one assistant
			// turn into a single user message.
			toolUseID := msg.ToolCallID
			content := msg.Content
			isError := false
			appendMsg(anth.Message{
				Role: anth.RoleUser,
				Content: []anth.MessageContent{
					{
						Type: anth.MessagesContentTypeToolResult,
						MessageContentToolResult: &anth.MessageContentToolResult{
							ToolUseID: &toolUseID,
							Content: []anth.MessageContent{
								{
									Type: anth.MessagesContentTypeText,
									Text: &content,
								},
							},
							IsError: &isError,
						},
					},
				},
			})
			continue
		}

		role := anth.RoleUser
		if msg.Role == "assistant" {
			role = anth.RoleAssistant
		}

		m := anth.Message{
			Role:    role,
			Content: []anth.MessageContent{},
		}

		if msg.Content != "" {
			// Copy so the pointer does not depend on loop-variable semantics.
			text := msg.Content
			m.Content = append(m.Content, anth.MessageContent{
				Type: anth.MessagesContentTypeText,
				Text: &text,
			})
		}

		// Tool calls previously issued by the assistant.
		for _, tc := range msg.ToolCalls {
			input := json.RawMessage("{}")
			if tc.Arguments != "" {
				input = json.RawMessage(tc.Arguments)
			}
			m.Content = append(m.Content, anth.MessageContent{
				Type: anth.MessagesContentTypeToolUse,
				MessageContentToolUse: &anth.MessageContentToolUse{
					ID:    tc.ID,
					Name:  tc.Name,
					Input: input,
				},
			})
		}

		for _, img := range msg.Images {
			// A message carrying images is always sent with the user role.
			if role == anth.RoleAssistant {
				role = anth.RoleUser
				m.Role = anth.RoleUser
			}

			switch {
			case img.Base64 != "":
				b64 := img.Base64
				contentType := img.ContentType

				// Compress payloads of 5 MiB or more; on any failure the
				// original data is sent unchanged.
				raw, err := base64.StdEncoding.DecodeString(b64)
				if err == nil && len(raw) >= 5242880 {
					compressed, mime, cerr := imageutil.CompressImage(b64, 5*1024*1024)
					if cerr == nil {
						b64 = compressed
						contentType = mime
					}
				}

				m.Content = append(m.Content, anth.NewImageMessageContent(
					anth.NewMessageContentSource(
						anth.MessagesContentSourceTypeBase64,
						contentType,
						b64,
					)))

			case img.URL != "":
				// The image is downloaded and inlined as base64.
				resp, err := http.Get(img.URL)
				if err != nil {
					continue
				}
				data, err := io.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					continue
				}
				// An error page must not be forwarded as image data.
				if resp.StatusCode < 200 || resp.StatusCode >= 300 {
					continue
				}

				contentType := resp.Header.Get("Content-Type")
				b64 := base64.StdEncoding.EncodeToString(data)
				m.Content = append(m.Content, anth.NewImageMessageContent(
					anth.NewMessageContentSource(
						anth.MessagesContentSourceTypeBase64,
						contentType,
						b64,
					)))
			}
		}

		appendMsg(m)
	}

	for _, tool := range req.Tools {
		anthReq.Tools = append(anthReq.Tools, anth.ToolDefinition{
			Name:        tool.Name,
			Description: tool.Description,
			InputSchema: tool.Schema,
		})
	}

	anthReq.Messages = msgs

	if req.Temperature != nil {
		f := float32(*req.Temperature)
		anthReq.Temperature = &f
	}
	if req.TopP != nil {
		f := float32(*req.TopP)
		anthReq.TopP = &f
	}
	if len(req.Stop) > 0 {
		anthReq.StopSequences = req.Stop
	}

	return anthReq
}
// convertResponse maps an Anthropic MessagesResponse onto provider.Response.
//
// Text blocks are concatenated in order without a separator. Each tool_use
// block becomes a provider.ToolCall whose Arguments is the JSON encoding of
// the block's input. Usage is always populated, with TotalTokens being the
// sum of input and output tokens.
func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response {
	var (
		out  provider.Response
		text strings.Builder
	)

	for _, blk := range resp.Content {
		switch {
		case blk.Type == anth.MessagesContentTypeText && blk.Text != nil:
			text.WriteString(*blk.Text)

		case blk.Type == anth.MessagesContentTypeToolUse && blk.MessageContentToolUse != nil:
			use := blk.MessageContentToolUse
			// A marshalling failure leaves the arguments empty, as before.
			encoded, _ := json.Marshal(use.Input)
			out.ToolCalls = append(out.ToolCalls, provider.ToolCall{
				ID:        use.ID,
				Name:      use.Name,
				Arguments: string(encoded),
			})
		}
	}
	out.Text = text.String()

	usage := resp.Usage
	out.Usage = &provider.Usage{
		InputTokens:  usage.InputTokens,
		OutputTokens: usage.OutputTokens,
		TotalTokens:  usage.InputTokens + usage.OutputTokens,
	}
	return out
}

153
v2/chat.go Normal file
View File

@@ -0,0 +1,153 @@
package llm
import (
"context"
"fmt"
)
// Chat manages a multi-turn conversation with automatic history tracking
// and optional automatic tool-call execution.
//
// The methods in this file read and modify the history without any locking.
type Chat struct {
	model    *Model          // model that every completion is sent to
	messages []Message       // conversation history, oldest first
	tools    *ToolBox        // tools offered to the model; nil means none
	opts     []RequestOption // options given to NewChat, applied to every request
}
// NewChat starts an empty conversation bound to model. The given options are
// kept and applied to every request issued through the returned Chat.
func NewChat(model *Model, opts ...RequestOption) *Chat {
	chat := new(Chat)
	chat.model = model
	chat.opts = opts
	return chat
}
// SetSystem installs text as the conversation's only system message.
// Any existing system messages are removed and the new one is placed at the
// front of the history; the order of all other messages is preserved.
func (c *Chat) SetSystem(text string) {
	rebuilt := make([]Message, 1, len(c.messages)+1)
	rebuilt[0] = SystemMessage(text)

	for i := range c.messages {
		if c.messages[i].Role == RoleSystem {
			continue
		}
		rebuilt = append(rebuilt, c.messages[i])
	}

	c.messages = rebuilt
}
// SetTools configures the tools available for this chat.
//
// The ToolBox pointer is stored as given and replaces any previous one.
// Passing nil removes the tools: requests then carry no tool option, and
// SendMessage returns ErrNoToolsConfigured if the model asks for a tool call.
func (c *Chat) SetTools(tb *ToolBox) {
	c.tools = tb
}
// Send wraps text in a user message and hands it to SendMessage, returning
// the assistant's final text. Tool calls requested by the model are executed
// automatically and the exchange repeats until the model answers without
// requesting tools (the "agent loop").
func (c *Chat) Send(ctx context.Context, text string) (string, error) {
	userMsg := UserMessage(text)
	return c.SendMessage(ctx, userMsg)
}
// SendWithImages behaves like Send but attaches images to the user message
// before passing it to SendMessage.
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) {
	userMsg := UserMessageWithImages(text, images...)
	return c.SendMessage(ctx, userMsg)
}
// SendMessage sends an arbitrary message and returns the final text response.
// It handles the full tool-call loop automatically: while the model answers
// with tool calls, they are executed through the configured ToolBox, their
// results are appended to the history, and the model is queried again.
//
// The loop has no iteration limit; it ends when the model replies without
// tool calls, when an error occurs, or when ctx is done (checked before every
// model call).
//
// The history is updated atomically: on success it gains msg, every
// intermediate assistant and tool message, and the final reply. On any error
// it is restored to its state before the call, so a failed exchange never
// leaves a dangling user message or an unanswered tool call behind.
//
// Errors: ctx.Err() if ctx is done, a wrapped completion error, a wrapped
// tool-execution error, or ErrNoToolsConfigured when the model requests tools
// but none are set.
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) {
	start := len(c.messages)

	// fail undoes every history change made by this call.
	fail := func(err error) (string, error) {
		c.messages = c.messages[:start]
		return "", err
	}

	c.messages = append(c.messages, msg)
	opts := c.buildOpts()

	for {
		if err := ctx.Err(); err != nil {
			return fail(err)
		}

		resp, err := c.model.Complete(ctx, c.messages, opts...)
		if err != nil {
			return fail(fmt.Errorf("completion failed: %w", err))
		}

		c.messages = append(c.messages, resp.Message())

		if !resp.HasToolCalls() {
			return resp.Text, nil
		}

		if c.tools == nil {
			return fail(ErrNoToolsConfigured)
		}

		toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
		if err != nil {
			return fail(fmt.Errorf("tool execution failed: %w", err))
		}

		c.messages = append(c.messages, toolResults...)
	}
}
// SendRaw sends a message and returns the raw Response without automatic tool
// execution. Useful when you want to handle tool calls manually; feed the
// results back with AddToolResults.
//
// On success the history gains msg followed by the assistant's message. If
// the completion fails, the error is returned unwrapped and msg is removed
// from the history again, so a retry does not send it twice.
func (c *Chat) SendRaw(ctx context.Context, msg Message) (Response, error) {
	start := len(c.messages)
	c.messages = append(c.messages, msg)

	opts := c.buildOpts()
	resp, err := c.model.Complete(ctx, c.messages, opts...)
	if err != nil {
		// Undo the append so the failed request leaves no trace.
		c.messages = c.messages[:start]
		return Response{}, err
	}

	c.messages = append(c.messages, resp.Message())
	return resp, nil
}
// SendStream sends a user message and returns a StreamReader for streaming
// responses.
//
// The user message is appended to the history before the request is built.
// If the stream cannot be created, the message is removed again and the
// error is returned as-is. This method itself never appends the assistant's
// reply and does not execute tool calls.
// NOTE(review): confirm whether StreamReader records the reply elsewhere;
// otherwise callers must add it to the history themselves.
func (c *Chat) SendStream(ctx context.Context, text string) (*StreamReader, error) {
	start := len(c.messages)
	c.messages = append(c.messages, UserMessage(text))

	// Resolve the request options into a concrete configuration.
	cfg := &requestConfig{}
	for _, opt := range c.buildOpts() {
		opt(cfg)
	}

	req := buildProviderRequest(c.model.model, c.messages, cfg)

	reader, err := newStreamReader(ctx, c.model.provider, req)
	if err != nil {
		// Undo the append so the failed request leaves no trace.
		c.messages = c.messages[:start]
		return nil, err
	}
	return reader, nil
}
// AddToolResults manually adds tool results to the conversation.
// Use with SendRaw when handling tool calls manually.
//
// The messages are appended to the end of the history in the order given;
// they are not validated and no request is sent.
func (c *Chat) AddToolResults(results ...Message) {
	c.messages = append(c.messages, results...)
}
// Messages returns a snapshot of the conversation history. The returned
// slice is a fresh, non-nil copy, so appending to it or replacing its
// elements does not affect the Chat.
func (c *Chat) Messages() []Message {
	snapshot := make([]Message, 0, len(c.messages))
	return append(snapshot, c.messages...)
}
// Reset clears the conversation history.
//
// Every message is dropped, including a system message installed with
// SetSystem. The model, tools and request options are left untouched.
func (c *Chat) Reset() {
	c.messages = nil
}
// Fork returns a new Chat that starts from the same history, for branching a
// conversation. The history slice is copied, so messages added to either chat
// afterwards are not seen by the other. The model, the ToolBox and the
// options slice are shared with the original.
func (c *Chat) Fork() *Chat {
	branch := *c
	branch.messages = append(make([]Message, 0, len(c.messages)), c.messages...)
	return &branch
}
// buildOpts assembles the request options for one call: a copy of the
// options given to NewChat, followed by WithTools when a ToolBox is set.
// The chat's own option slice is never modified.
func (c *Chat) buildOpts() []RequestOption {
	merged := make([]RequestOption, 0, len(c.opts)+1)
	merged = append(merged, c.opts...)
	if c.tools == nil {
		return merged
	}
	return append(merged, WithTools(c.tools))
}

48
v2/constructors.go Normal file
View File

@@ -0,0 +1,48 @@
package llm
import (
anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic"
googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google"
openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
)
// OpenAI returns a Client backed by the OpenAI provider. The options are
// applied to a fresh client configuration, and its base URL is handed to the
// provider together with apiKey.
//
// Example:
//
//	model := llm.OpenAI("sk-...").Model("gpt-4o")
func OpenAI(apiKey string, opts ...ClientOption) *Client {
	var cfg clientConfig
	for i := range opts {
		opts[i](&cfg)
	}

	backend := openaiProvider.New(apiKey, cfg.baseURL)
	return newClient(backend)
}
// Anthropic returns a Client backed by the Anthropic provider.
//
// The options are still applied to a client configuration, but the provider
// constructor takes only the API key, so settings such as a custom base URL
// currently have no effect.
//
// Example:
//
//	model := llm.Anthropic("sk-ant-...").Model("claude-sonnet-4-20250514")
func Anthropic(apiKey string, opts ...ClientOption) *Client {
	var cfg clientConfig
	for i := range opts {
		opts[i](&cfg)
	}

	backend := anthProvider.New(apiKey)
	return newClient(backend)
}
// Google returns a Client backed by the Google (Gemini) provider.
//
// The options are still applied to a client configuration, but the provider
// constructor takes only the API key, so settings such as a custom base URL
// currently have no effect.
//
// Example:
//
//	model := llm.Google("...").Model("gemini-2.0-flash")
func Google(apiKey string, opts ...ClientOption) *Client {
	var cfg clientConfig
	for i := range opts {
		opts[i](&cfg)
	}

	backend := googleProvider.New(apiKey)
	return newClient(backend)
}

17
v2/errors.go Normal file
View File

@@ -0,0 +1,17 @@
package llm
import "errors"
// Sentinel errors exposed by the package. They may reach callers wrapped in
// additional context, so compare with errors.Is rather than ==.
var (
	// ErrNoToolsConfigured is returned when the model requests tool calls but no tools are available.
	ErrNoToolsConfigured = errors.New("model requested tool calls but no tools configured")
	// ErrToolNotFound is returned when a requested tool is not in the toolbox.
	ErrToolNotFound = errors.New("tool not found")
	// ErrNotConnected is returned when trying to use an MCP server that isn't connected.
	ErrNotConnected = errors.New("MCP server not connected")
	// ErrStreamClosed is returned when trying to read from a closed stream.
	ErrStreamClosed = errors.New("stream closed")
)

39
v2/go.mod Normal file
View File

@@ -0,0 +1,39 @@
module gitea.stevedudenhoeffer.com/steve/go-llm/v2
go 1.24.0
toolchain go1.24.2
require (
github.com/liushuangls/go-anthropic/v2 v2.17.0
github.com/modelcontextprotocol/go-sdk v1.2.0
github.com/openai/openai-go v1.12.0
golang.org/x/image v0.35.0
google.golang.org/genai v1.45.0
)
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.9.3 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc v1.66.2 // indirect
google.golang.org/protobuf v1.34.2 // indirect
)

159
v2/go.sum Normal file
View File

@@ -0,0 +1,159 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U=
cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genai v1.45.0 h1:s80ZpS42XW0zu/ogiOtenCio17nJ7reEFJjoCftukpA=
google.golang.org/genai v1.45.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo=
google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

322
v2/google/google.go Normal file
View File

@@ -0,0 +1,322 @@
// Package google implements the go-llm v2 provider interface for Google (Gemini).
package google
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
"google.golang.org/genai"
)
// Provider implements the provider.Provider interface for Google Gemini.
//
// It stores only the API key; Complete and Stream each create their own
// genai client from that key when they are called.
type Provider struct {
	apiKey string // key placed in genai.ClientConfig for every request
}
// New returns a Google provider that authenticates with apiKey.
// The key is stored as given; nothing is validated and no client is created here.
func New(apiKey string) *Provider {
	p := new(Provider)
	p.apiKey = apiKey
	return p
}
// Complete runs one non-streaming generation against the Gemini API backend
// and returns the converted result. Failures to create the client and
// failures of the generation call are wrapped with distinct prefixes.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	clientCfg := genai.ClientConfig{
		APIKey:  p.apiKey,
		Backend: genai.BackendGeminiAPI,
	}

	client, err := genai.NewClient(ctx, &clientCfg)
	if err != nil {
		var none provider.Response
		return none, fmt.Errorf("google client error: %w", err)
	}

	contents, genCfg := p.buildRequest(req)

	result, err := client.Models.GenerateContent(ctx, req.Model, contents, genCfg)
	if err != nil {
		var none provider.Response
		return none, fmt.Errorf("google completion error: %w", err)
	}

	return p.convertResponse(result)
}
// Stream performs a streaming completion.
//
// For every streamed chunk, each non-empty text part is forwarded on events
// as a provider.StreamEventText, and each function call is reported as a
// provider.StreamEventToolStart immediately followed by a
// provider.StreamEventToolEnd carrying the same tool call. The tool call ID is
// the function name. After the stream ends, one provider.StreamEventDone with
// the accumulated text and tool calls is sent; that response carries no usage
// data.
//
// Each send on events also watches ctx, so a consumer that stops reading
// after cancelling ctx cannot block this call forever. The events channel is
// not closed by this method.
//
// It returns a wrapped error if the client cannot be created, if the stream
// reports an error, or if a function call's arguments cannot be encoded as
// JSON; it returns ctx.Err() if ctx is done before an event could be
// delivered.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  p.apiKey,
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		return fmt.Errorf("google client error: %w", err)
	}

	// send delivers ev unless ctx is cancelled first.
	send := func(ev provider.StreamEvent) error {
		select {
		case events <- ev:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	contents, cfg := p.buildRequest(req)

	var fullText strings.Builder
	var toolCalls []provider.ToolCall

	for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
		if err != nil {
			return fmt.Errorf("google stream error: %w", err)
		}

		for _, c := range resp.Candidates {
			if c.Content == nil {
				continue
			}
			for _, part := range c.Content.Parts {
				if part.Text != "" {
					fullText.WriteString(part.Text)
					if err := send(provider.StreamEvent{
						Type: provider.StreamEventText,
						Text: part.Text,
					}); err != nil {
						return err
					}
				}

				if part.FunctionCall == nil {
					continue
				}

				args, err := json.Marshal(part.FunctionCall.Args)
				if err != nil {
					return fmt.Errorf("google stream error: encoding arguments of tool %q: %w", part.FunctionCall.Name, err)
				}
				tc := provider.ToolCall{
					ID:        part.FunctionCall.Name,
					Name:      part.FunctionCall.Name,
					Arguments: string(args),
				}
				toolCalls = append(toolCalls, tc)
				idx := len(toolCalls) - 1

				// The call arrives complete in one part, so start and end
				// are reported back to back.
				if err := send(provider.StreamEvent{
					Type:      provider.StreamEventToolStart,
					ToolCall:  &tc,
					ToolIndex: idx,
				}); err != nil {
					return err
				}
				if err := send(provider.StreamEvent{
					Type:      provider.StreamEventToolEnd,
					ToolCall:  &tc,
					ToolIndex: idx,
				}); err != nil {
					return err
				}
			}
		}
	}

	return send(provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
		},
	})
}
// buildRequest translates a provider.Request into the contents and config
// expected by the genai client.
//
//   - Each tool becomes its own genai.Tool with a single function declaration.
//   - Temperature, MaxTokens, TopP and Stop are copied when set.
//   - A "system" message becomes cfg.SystemInstruction; if several are present
//     the last one wins.
//   - A "tool" message becomes a user-role FunctionResponse whose Name is the
//     message's ToolCallID (this provider uses the function name as call ID).
//   - "assistant" maps to the model role; every other role maps to user.
//
// Images are attached best-effort: a URL image is downloaded (Gemini does not
// accept URLs directly) and a Base64 image is decoded; an image that cannot be
// fetched or decoded is silently skipped so one bad attachment does not fail
// the whole request.
func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) {
	var contents []*genai.Content
	cfg := &genai.GenerateContentConfig{}

	for _, tool := range req.Tools {
		cfg.Tools = append(cfg.Tools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{
				{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters:  schemaToGenai(tool.Schema),
				},
			},
		})
	}

	if req.Temperature != nil {
		f := float32(*req.Temperature)
		cfg.Temperature = &f
	}
	if req.MaxTokens != nil {
		cfg.MaxOutputTokens = int32(*req.MaxTokens)
	}
	if req.TopP != nil {
		f := float32(*req.TopP)
		cfg.TopP = &f
	}
	if len(req.Stop) > 0 {
		cfg.StopSequences = req.Stop
	}

	for _, msg := range req.Messages {
		var role genai.Role
		switch msg.Role {
		case "system":
			cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser)
			continue
		case "assistant":
			role = genai.RoleModel
		case "tool":
			// Tool results go as function responses (genai uses RoleUser for function responses).
			contents = append(contents, &genai.Content{
				Role: genai.RoleUser,
				Parts: []*genai.Part{
					{
						FunctionResponse: &genai.FunctionResponse{
							Name: msg.ToolCallID,
							Response: map[string]any{
								"result": msg.Content,
							},
						},
					},
				},
			})
			continue
		default:
			role = genai.RoleUser
		}

		var parts []*genai.Part
		if msg.Content != "" {
			parts = append(parts, genai.NewPartFromText(msg.Content))
		}

		// Replay tool calls the assistant made earlier in the conversation.
		for _, tc := range msg.ToolCalls {
			var args map[string]any
			if tc.Arguments != "" {
				// Best-effort: malformed arguments are replayed as an empty call.
				_ = json.Unmarshal([]byte(tc.Arguments), &args)
			}
			parts = append(parts, &genai.Part{
				FunctionCall: &genai.FunctionCall{
					Name: tc.Name,
					Args: args,
				},
			})
		}

		for _, img := range msg.Images {
			switch {
			case img.URL != "":
				data, err := downloadImage(img.URL)
				if err != nil {
					continue // skip images that cannot be fetched
				}
				parts = append(parts, genai.NewPartFromBytes(data, http.DetectContentType(data)))
			case img.Base64 != "":
				data, err := base64.StdEncoding.DecodeString(img.Base64)
				if err != nil {
					continue // skip images that are not valid base64
				}
				parts = append(parts, genai.NewPartFromBytes(data, img.ContentType))
			}
		}

		contents = append(contents, genai.NewContentFromParts(parts, role))
	}

	return contents, cfg
}

// downloadImage fetches url and returns the response body. A non-2xx status is
// reported as an error so that an HTTP error page is never mistaken for image
// data.
func downloadImage(url string) ([]byte, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("fetching image %q: unexpected status %s", url, resp.Status)
	}
	return io.ReadAll(resp.Body)
}
// convertResponse flattens a genai response into a provider.Response.
//
// Text parts from every candidate are concatenated in order. Each function
// call becomes a ToolCall whose ID and Name are both the function name and
// whose Arguments is the JSON encoding of the call's Args. Usage is populated
// only when the API returned usage metadata.
//
// An error is returned (with a zero Response) if a function call's arguments
// cannot be encoded as JSON.
func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) {
	var (
		res  provider.Response
		text strings.Builder
	)

	for _, c := range resp.Candidates {
		if c.Content == nil {
			continue
		}
		for _, part := range c.Content.Parts {
			if part.Text != "" {
				text.WriteString(part.Text)
			}
			if fc := part.FunctionCall; fc != nil {
				args, err := json.Marshal(fc.Args)
				if err != nil {
					return provider.Response{}, fmt.Errorf("google completion error: encoding arguments for %q: %w", fc.Name, err)
				}
				res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
					ID:        fc.Name,
					Name:      fc.Name,
					Arguments: string(args),
				})
			}
		}
	}
	res.Text = text.String()

	if um := resp.UsageMetadata; um != nil {
		res.Usage = &provider.Usage{
			InputTokens:  int(um.PromptTokenCount),
			OutputTokens: int(um.CandidatesTokenCount),
			TotalTokens:  int(um.TotalTokenCount),
		}
	}

	return res, nil
}
// schemaToGenai recursively converts a JSON Schema expressed as map[string]any
// into a *genai.Schema. A nil map yields nil.
//
// Only "type", "description", "properties", "required", "items" and "enum" are
// translated. An unrecognised "type" leaves the schema type unset, property
// values that are not maps are dropped, and non-string entries in "required"
// or "enum" are ignored.
func schemaToGenai(s map[string]any) *genai.Schema {
	if s == nil {
		return nil
	}

	// stringList accepts either []string (as produced in-process) or []any
	// (as produced by decoding JSON) and returns the string entries.
	stringList := func(v any) []string {
		switch list := v.(type) {
		case []string:
			return list
		case []any:
			var out []string
			for _, item := range list {
				if str, ok := item.(string); ok {
					out = append(out, str)
				}
			}
			return out
		}
		return nil
	}

	kinds := map[string]genai.Type{
		"object":  genai.TypeObject,
		"array":   genai.TypeArray,
		"string":  genai.TypeString,
		"integer": genai.TypeInteger,
		"number":  genai.TypeNumber,
		"boolean": genai.TypeBoolean,
	}

	out := &genai.Schema{}

	if name, ok := s["type"].(string); ok {
		if kind, known := kinds[name]; known {
			out.Type = kind
		}
	}

	if text, ok := s["description"].(string); ok {
		out.Description = text
	}

	if props, ok := s["properties"].(map[string]any); ok {
		out.Properties = make(map[string]*genai.Schema)
		for key, raw := range props {
			child, isMap := raw.(map[string]any)
			if !isMap {
				continue
			}
			out.Properties[key] = schemaToGenai(child)
		}
	}

	out.Required = stringList(s["required"])

	if items, ok := s["items"].(map[string]any); ok {
		out.Items = schemaToGenai(items)
	}

	out.Enum = stringList(s["enum"])

	return out
}

View File

@@ -0,0 +1,105 @@
// Package imageutil provides image compression utilities.
package imageutil
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/gif"
"image/jpeg"
_ "image/png" // register PNG decoder
"net/http"
"golang.org/x/image/draw"
)
// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns a
// base-64-encoded image together with its MIME type.
//
// The size limit is checked against the decoded bytes: when the decoded input
// is already at most maxLength bytes it is returned unchanged with the MIME
// type sniffed from its content. Otherwise GIFs are handed to compressGIF and
// everything else to compressRaster. An error is returned when the input is
// not valid base64 or when compression fails.
func CompressImage(b64 string, maxLength int) (string, string, error) {
	raw, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return "", "", fmt.Errorf("base64 decode: %w", err)
	}

	mime := http.DetectContentType(raw)

	switch {
	case len(raw) <= maxLength:
		return b64, mime, nil
	case mime == "image/gif":
		return compressGIF(raw, maxLength)
	}
	return compressRaster(raw, maxLength)
}
// compressRaster decodes src with the registered image decoders and re-encodes
// it as JPEG until the encoded size is at most maxLength bytes.
//
// For the current image it tries JPEG quality 95 down to 20 in steps of 5. If
// none fits, the image is scaled to 80% of its width and height and the
// quality sweep starts over. It gives up with an error once either dimension
// is below 100 pixels. On success it returns the base-64 JPEG and "image/jpeg".
func compressRaster(src []byte, maxLength int) (string, string, error) {
	img, _, err := image.Decode(bytes.NewReader(src))
	if err != nil {
		return "", "", fmt.Errorf("decode raster: %w", err)
	}

	for {
		// Quality sweep at the current resolution.
		for q := 95; q >= 20; q -= 5 {
			var encoded bytes.Buffer
			if err := jpeg.Encode(&encoded, img, &jpeg.Options{Quality: q}); err != nil {
				return "", "", fmt.Errorf("jpeg encode: %w", err)
			}
			if encoded.Len() <= maxLength {
				return base64.StdEncoding.EncodeToString(encoded.Bytes()), "image/jpeg", nil
			}
		}

		// Still too large at the lowest quality: shrink and try again.
		bounds := img.Bounds()
		if bounds.Dx() < 100 || bounds.Dy() < 100 {
			return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
		}

		newW := int(float64(bounds.Dx()) * 0.8)
		newH := int(float64(bounds.Dy()) * 0.8)
		scaled := image.NewRGBA(image.Rect(0, 0, newW, newH))
		draw.ApproxBiLinear.Scale(scaled, scaled.Bounds(), img, bounds, draw.Over, nil)
		img = scaled
	}
}
// compressGIF re-encodes a (possibly animated) GIF, shrinking every frame to
// 80% of its size per round until the encoded result is at most maxLength
// bytes. It gives up with an error once the canvas width or height is below
// 100 pixels. On success it returns the base-64 GIF and "image/gif".
//
// Each frame keeps its own colour palette and its position on the canvas:
// frames that cover only part of the canvas are scaled within their own
// (scaled) rectangle rather than being stretched over the whole image.
func compressGIF(src []byte, maxLength int) (string, string, error) {
	const factor = 0.8

	g, err := gif.DecodeAll(bytes.NewReader(src))
	if err != nil {
		return "", "", fmt.Errorf("gif decode: %w", err)
	}

	for {
		var buf bytes.Buffer
		if err := gif.EncodeAll(&buf, g); err != nil {
			return "", "", fmt.Errorf("gif encode: %w", err)
		}
		if buf.Len() <= maxLength {
			return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
		}

		w, h := g.Config.Width, g.Config.Height
		if w < 100 || h < 100 {
			return "", "", fmt.Errorf("cannot compress animated GIF below %.02fMiB", float64(maxLength)/1048576.0)
		}

		nw, nh := int(float64(w)*factor), int(float64(h)*factor)
		for i, frm := range g.Image {
			fb := frm.Bounds()
			nb := scaleGIFFrameBounds(fb, factor, nw, nh)

			// Expand the paletted frame to RGBA so it can be resampled.
			rgba := image.NewRGBA(fb)
			draw.Draw(rgba, fb, frm, fb.Min, draw.Src)

			dst := image.NewRGBA(nb)
			draw.ApproxBiLinear.Scale(dst, nb, rgba, fb, draw.Over, nil)

			// Quantise back using the frame's original palette; an empty
			// palette cannot be dithered into or encoded.
			paletted := image.NewPaletted(nb, frm.Palette)
			draw.FloydSteinberg.Draw(paletted, nb, dst, nb.Min)
			g.Image[i] = paletted
		}
		g.Config.Width, g.Config.Height = nw, nh
	}
}

// scaleGIFFrameBounds scales a frame rectangle by factor and clamps it so the
// result is at least one pixel in each dimension and lies inside a canvas of
// maxW x maxH pixels.
func scaleGIFFrameBounds(b image.Rectangle, factor float64, maxW, maxH int) image.Rectangle {
	x0 := min(int(float64(b.Min.X)*factor), maxW-1)
	y0 := min(int(float64(b.Min.Y)*factor), maxH-1)
	x1 := min(max(int(float64(b.Max.X)*factor), x0+1), maxW)
	y1 := min(max(int(float64(b.Max.Y)*factor), y0+1), maxH)
	return image.Rect(x0, y0, x1, y1)
}

View File

@@ -0,0 +1,188 @@
// Package schema provides JSON Schema generation from Go structs.
// It produces standard JSON Schema as map[string]any, with no provider-specific types.
package schema
import (
"reflect"
"strings"
)
// FromStruct generates a JSON Schema object from a Go struct.
// Struct tags used:
//   - `json:"name"` — sets the property name (standard Go JSON convention)
//   - `description:"..."` — sets the property description
//   - `enum:"a,b,c"` — restricts string values to the given set
//
// Pointer fields are treated as optional; non-pointer fields are required.
// Anonymous (embedded) struct fields are flattened into the parent.
//
// v must be a struct or a pointer to a struct. Anything else, including an
// untyped nil, panics with a descriptive message.
func FromStruct(v any) map[string]any {
	t := reflect.TypeOf(v)
	// reflect.TypeOf returns nil for an untyped nil; guard before calling Kind.
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		panic("schema.FromStruct expects a struct or pointer to struct")
	}
	return objectSchema(t)
}
// objectSchema builds a JSON Schema of type "object" for the struct type t.
//
// Unexported fields and fields tagged `json:"-"` are skipped, because
// encoding/json never populates them. Anonymous struct fields (or pointers to
// structs) are flattened: their properties and required names are merged into
// the parent. For every other field the property name comes from fieldName and
// the property schema from fieldSchema; non-pointer fields are listed under
// "required". The "required" key is omitted when no field is required.
func objectSchema(t reflect.Type) map[string]any {
	properties := map[string]any{}
	var required []string

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// Skip unexported fields.
		if !field.IsExported() {
			continue
		}

		// Skip fields explicitly excluded from JSON. Note that `json:"-,"`
		// is different: it names the field "-".
		if field.Tag.Get("json") == "-" {
			continue
		}

		// Flatten anonymous (embedded) structs.
		if field.Anonymous {
			ft := field.Type
			if ft.Kind() == reflect.Ptr {
				ft = ft.Elem()
			}
			if ft.Kind() == reflect.Struct {
				embedded := objectSchema(ft)
				if props, ok := embedded["properties"].(map[string]any); ok {
					for k, v := range props {
						properties[k] = v
					}
				}
				if req, ok := embedded["required"].([]string); ok {
					required = append(required, req...)
				}
			}
			continue
		}

		name := fieldName(field)

		// A pointer marks the field as optional; describe the pointed-to type.
		isRequired := true
		ft := field.Type
		if ft.Kind() == reflect.Ptr {
			ft = ft.Elem()
			isRequired = false
		}

		properties[name] = fieldSchema(field, ft)
		if isRequired {
			required = append(required, name)
		}
	}

	result := map[string]any{
		"type":       "object",
		"properties": properties,
	}
	if len(required) > 0 {
		result["required"] = required
	}
	return result
}
// fieldSchema returns the JSON Schema for one struct field. ft is the field's
// type with any leading pointer already removed by the caller.
//
// An `enum` tag takes precedence over the Go type and always yields a string
// schema restricted to the listed values. Otherwise the schema follows the
// kind of ft: strings, integers, floats and booleans map to their JSON types,
// structs to nested object schemas, slices and arrays to "array" with an
// "items" schema, and string-keyed maps to "object" with
// "additionalProperties". Any other kind falls back to "string".
//
// A `description` tag, when present, is attached in every case, including
// nested structs.
func fieldSchema(field reflect.StructField, ft reflect.Type) map[string]any {
	prop := map[string]any{}

	// Check for enum tag first.
	if enumTag, ok := field.Tag.Lookup("enum"); ok {
		prop["type"] = "string"
		prop["enum"] = parseEnum(enumTag)
		if desc, ok := field.Tag.Lookup("description"); ok {
			prop["description"] = desc
		}
		return prop
	}

	switch ft.Kind() {
	case reflect.String:
		prop["type"] = "string"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		prop["type"] = "integer"
	case reflect.Float32, reflect.Float64:
		prop["type"] = "number"
	case reflect.Bool:
		prop["type"] = "boolean"
	case reflect.Struct:
		// objectSchema returns a fresh map, so the description can be added
		// to it below without affecting any other schema.
		prop = objectSchema(ft)
	case reflect.Slice, reflect.Array:
		prop["type"] = "array"
		elemType := ft.Elem()
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		prop["items"] = typeSchema(elemType)
	case reflect.Map:
		prop["type"] = "object"
		if ft.Key().Kind() == reflect.String {
			valType := ft.Elem()
			if valType.Kind() == reflect.Ptr {
				valType = valType.Elem()
			}
			prop["additionalProperties"] = typeSchema(valType)
		}
	default:
		prop["type"] = "string" // fallback
	}

	if desc, ok := field.Tag.Lookup("description"); ok {
		prop["description"] = desc
	}
	return prop
}
// typeSchema returns the JSON Schema for a bare Go type, i.e. one that has no
// struct field (and therefore no tags) attached, such as a slice element or a
// map value.
//
// It mirrors the kind mapping used by fieldSchema: primitives map to their
// JSON types, structs to object schemas, slices and arrays to "array" with
// "items", and string-keyed maps to "object" with "additionalProperties".
// Pointer element and value types are dereferenced once. Any other kind falls
// back to "string".
func typeSchema(t reflect.Type) map[string]any {
	switch t.Kind() {
	case reflect.String:
		return map[string]any{"type": "string"}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return map[string]any{"type": "integer"}
	case reflect.Float32, reflect.Float64:
		return map[string]any{"type": "number"}
	case reflect.Bool:
		return map[string]any{"type": "boolean"}
	case reflect.Struct:
		return objectSchema(t)
	case reflect.Slice, reflect.Array:
		elemType := t.Elem()
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		return map[string]any{
			"type":  "array",
			"items": typeSchema(elemType),
		}
	case reflect.Map:
		out := map[string]any{"type": "object"}
		if t.Key().Kind() == reflect.String {
			valType := t.Elem()
			if valType.Kind() == reflect.Ptr {
				valType = valType.Elem()
			}
			out["additionalProperties"] = typeSchema(valType)
		}
		return out
	default:
		return map[string]any{"type": "string"}
	}
}
// fieldName returns the JSON property name for a struct field.
//
// The name part of the `json` tag is used when it is non-empty. Following
// encoding/json, a tag of exactly "-" does not name the field (the Go field
// name is returned), whereas "-," names the field "-". Without a usable tag
// name the Go field name is returned.
func fieldName(f reflect.StructField) string {
	tag, ok := f.Tag.Lookup("json")
	if !ok {
		return f.Name
	}

	name, _, hasOpts := strings.Cut(tag, ",")
	switch {
	case name == "":
		return f.Name
	case name == "-" && !hasOpts:
		// `json:"-"` means "omit"; only `json:"-,"` spells the name "-".
		return f.Name
	}
	return name
}
// parseEnum splits the value of an `enum` struct tag on commas and returns the
// entries with surrounding whitespace removed. Entries that are empty after
// trimming are dropped; the result is nil when nothing remains.
func parseEnum(tag string) []string {
	var vals []string
	rest := tag
	for {
		item, tail, more := strings.Cut(rest, ",")
		if v := strings.TrimSpace(item); v != "" {
			vals = append(vals, v)
		}
		if !more {
			return vals
		}
		rest = tail
	}
}

View File

@@ -0,0 +1,181 @@
package schema
import (
"encoding/json"
"testing"
)
// SimpleParams is a fixture with two required primitive fields.
type SimpleParams struct {
	Name string `json:"name" description:"The name"`
	Age  int    `json:"age" description:"The age"`
}

// OptionalParams is a fixture mixing a required field with a pointer field,
// which the schema generator treats as optional.
type OptionalParams struct {
	Required string  `json:"required" description:"A required field"`
	Optional *string `json:"optional,omitempty" description:"An optional field"`
}

// EnumParams is a fixture whose field carries an `enum` tag.
type EnumParams struct {
	Color string `json:"color" description:"The color" enum:"red,green,blue"`
}

// NestedParams is a fixture containing a nested struct field.
type NestedParams struct {
	Inner SimpleParams `json:"inner" description:"Nested object"`
}

// ArrayParams is a fixture containing a slice field.
type ArrayParams struct {
	Items []string `json:"items" description:"A list of items"`
}

// EmbeddedBase is embedded by EmbeddedParams to exercise flattening.
type EmbeddedBase struct {
	ID string `json:"id" description:"The ID"`
}

// EmbeddedParams is a fixture with an anonymous embedded struct plus a field
// of its own.
type EmbeddedParams struct {
	EmbeddedBase
	Name string `json:"name" description:"The name"`
}
// TestFromStruct_Simple checks the schema generated for a struct with two
// required primitive fields: the top-level type, both property schemas, and
// the required list.
func TestFromStruct_Simple(t *testing.T) {
	got := FromStruct(SimpleParams{})

	if typ := got["type"]; typ != "object" {
		t.Errorf("expected type=object, got %v", typ)
	}

	props, ok := got["properties"].(map[string]any)
	if !ok {
		t.Fatal("expected properties to be map[string]any")
	}
	if n := len(props); n != 2 {
		t.Errorf("expected 2 properties, got %d", n)
	}

	// name: string with a description.
	name, ok := props["name"].(map[string]any)
	if !ok {
		t.Fatal("expected name property to be map[string]any")
	}
	if typ := name["type"]; typ != "string" {
		t.Errorf("expected name type=string, got %v", typ)
	}
	if desc := name["description"]; desc != "The name" {
		t.Errorf("expected name description='The name', got %v", desc)
	}

	// age: integer.
	age, ok := props["age"].(map[string]any)
	if !ok {
		t.Fatal("expected age property to be map[string]any")
	}
	if typ := age["type"]; typ != "integer" {
		t.Errorf("expected age type=integer, got %v", typ)
	}

	// Both non-pointer fields must be required.
	req, ok := got["required"].([]string)
	if !ok {
		t.Fatal("expected required to be []string")
	}
	if n := len(req); n != 2 {
		t.Errorf("expected 2 required fields, got %d", n)
	}
}
// TestFromStruct_Optional verifies that a pointer field is left out of the
// required list while the non-pointer field is included.
func TestFromStruct_Optional(t *testing.T) {
	schema := FromStruct(OptionalParams{})

	req, ok := schema["required"].([]string)
	if !ok {
		t.Fatal("expected required to be []string")
	}

	// Only "required" field should be required, not "optional".
	if n := len(req); n != 1 {
		t.Errorf("expected 1 required field, got %d: %v", n, req)
	}
	if first := req[0]; first != "required" {
		t.Errorf("expected required field 'required', got %v", first)
	}
}
// TestFromStruct_Enum verifies that an `enum` tag produces a string schema
// listing the tag's values. Type assertions are checked so that a malformed
// schema fails the test instead of panicking.
func TestFromStruct_Enum(t *testing.T) {
	s := FromStruct(EnumParams{})

	props, ok := s["properties"].(map[string]any)
	if !ok {
		t.Fatal("expected properties to be map[string]any")
	}
	colorSchema, ok := props["color"].(map[string]any)
	if !ok {
		t.Fatal("expected color property to be map[string]any")
	}

	if colorSchema["type"] != "string" {
		t.Errorf("expected enum type=string, got %v", colorSchema["type"])
	}

	enums, ok := colorSchema["enum"].([]string)
	if !ok {
		t.Fatal("expected enum to be []string")
	}
	if len(enums) != 3 {
		t.Errorf("expected 3 enum values, got %d", len(enums))
	}
}
// TestFromStruct_Nested verifies that a struct-typed field becomes a nested
// object schema with the inner struct's properties. Type assertions are
// checked so that a malformed schema fails the test instead of panicking.
func TestFromStruct_Nested(t *testing.T) {
	s := FromStruct(NestedParams{})

	props, ok := s["properties"].(map[string]any)
	if !ok {
		t.Fatal("expected properties to be map[string]any")
	}
	innerSchema, ok := props["inner"].(map[string]any)
	if !ok {
		t.Fatal("expected inner property to be map[string]any")
	}

	if innerSchema["type"] != "object" {
		t.Errorf("expected nested type=object, got %v", innerSchema["type"])
	}

	innerProps, ok := innerSchema["properties"].(map[string]any)
	if !ok {
		t.Fatal("expected inner properties to be map[string]any")
	}
	if len(innerProps) != 2 {
		t.Errorf("expected 2 inner properties, got %d", len(innerProps))
	}
}
// TestFromStruct_Array verifies that a []string field becomes an array schema
// whose items are strings. Type assertions are checked so that a malformed
// schema fails the test instead of panicking.
func TestFromStruct_Array(t *testing.T) {
	s := FromStruct(ArrayParams{})

	props, ok := s["properties"].(map[string]any)
	if !ok {
		t.Fatal("expected properties to be map[string]any")
	}
	itemsSchema, ok := props["items"].(map[string]any)
	if !ok {
		t.Fatal("expected items property to be map[string]any")
	}

	if itemsSchema["type"] != "array" {
		t.Errorf("expected array type=array, got %v", itemsSchema["type"])
	}

	items, ok := itemsSchema["items"].(map[string]any)
	if !ok {
		t.Fatal("expected array items schema to be map[string]any")
	}
	if items["type"] != "string" {
		t.Errorf("expected items type=string, got %v", items["type"])
	}
}
// TestFromStruct_Embedded verifies that the fields of an anonymous embedded
// struct are flattened into the parent's properties. The type assertion is
// checked so that a malformed schema fails the test instead of panicking.
func TestFromStruct_Embedded(t *testing.T) {
	s := FromStruct(EmbeddedParams{})

	props, ok := s["properties"].(map[string]any)
	if !ok {
		t.Fatal("expected properties to be map[string]any")
	}

	// Should have both ID from embedded and Name.
	if len(props) != 2 {
		t.Errorf("expected 2 properties (flattened), got %d", len(props))
	}
	if _, ok := props["id"]; !ok {
		t.Error("expected 'id' property from embedded struct")
	}
	if _, ok := props["name"]; !ok {
		t.Error("expected 'name' property")
	}
}
// TestFromStruct_ValidJSON verifies that a generated schema can be marshalled
// to JSON and decoded back into a map.
func TestFromStruct_ValidJSON(t *testing.T) {
	encoded, err := json.Marshal(FromStruct(SimpleParams{}))
	if err != nil {
		t.Fatalf("schema should be valid JSON: %v", err)
	}

	var decoded map[string]any
	err = json.Unmarshal(encoded, &decoded)
	if err != nil {
		t.Fatalf("schema should round-trip through JSON: %v", err)
	}
}

199
v2/llm.go Normal file
View File

@@ -0,0 +1,199 @@
package llm
import (
"context"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// Client represents an LLM provider. Create with OpenAI(), Anthropic(), Google().
// Models obtained through Model inherit the client's middleware.
type Client struct {
	p          provider.Provider // backend that executes requests
	middleware []Middleware      // handed to every Model created from this client
}
// newClient wraps the given provider in a Client that has no middleware.
func newClient(p provider.Provider) *Client {
	c := new(Client)
	c.p = p
	return c
}
// Model returns a handle for the named model version on this client's
// provider. The handle shares the client's middleware slice.
func (c *Client) Model(modelVersion string) *Model {
	m := new(Model)
	m.provider = c.p
	m.model = modelVersion
	m.middleware = c.middleware
	return m
}
// WithMiddleware returns a copy of the client whose middleware list is the
// receiver's list followed by mw. The receiver is left unmodified; the list
// is copied rather than appended to in place.
func (c *Client) WithMiddleware(mw ...Middleware) *Client {
	combined := make([]Middleware, 0, len(c.middleware)+len(mw))
	combined = append(combined, c.middleware...)
	combined = append(combined, mw...)

	return &Client{p: c.p, middleware: combined}
}
// Model represents a specific model from a provider, ready for completions.
type Model struct {
	provider   provider.Provider // backend that executes requests
	model      string            // model identifier sent with each request
	middleware []Middleware      // wrapped around Complete calls by buildChain
}
// Complete sends a non-streaming completion request. The request options are
// applied in order to a fresh configuration, and the call then runs through
// the model's middleware chain before reaching the provider.
func (m *Model) Complete(ctx context.Context, messages []Message, opts ...RequestOption) (Response, error) {
	cfg := new(requestConfig)
	for i := range opts {
		opts[i](cfg)
	}

	run := m.buildChain()
	return run(ctx, m.model, messages, cfg)
}
// Stream sends a streaming completion request and returns a StreamReader for
// consuming the result. The request options are applied in order to a fresh
// configuration. The provider is handed to the stream reader directly; the
// model's middleware is not consulted on this path.
func (m *Model) Stream(ctx context.Context, messages []Message, opts ...RequestOption) (*StreamReader, error) {
	cfg := new(requestConfig)
	for i := range opts {
		opts[i](cfg)
	}

	return newStreamReader(ctx, m.provider, buildProviderRequest(m.model, messages, cfg))
}
// WithMiddleware returns a copy of the model whose middleware list is the
// receiver's list followed by mw. The receiver is left unmodified; the list
// is copied rather than appended to in place.
func (m *Model) WithMiddleware(mw ...Middleware) *Model {
	combined := make([]Middleware, 0, len(m.middleware)+len(mw))
	combined = append(combined, m.middleware...)
	combined = append(combined, mw...)

	clone := new(Model)
	clone.provider = m.provider
	clone.model = m.model
	clone.middleware = combined
	return clone
}
// buildChain assembles the completion handler for this model: the innermost
// handler calls the provider, and each middleware wraps the handler built so
// far. Middleware is applied from last to first so that the first entry in the
// list ends up outermost.
func (m *Model) buildChain() CompletionFunc {
	// Innermost handler: translate, call the provider, translate back.
	var handler CompletionFunc = func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
		resp, err := m.provider.Complete(ctx, buildProviderRequest(model, messages, cfg))
		if err != nil {
			return Response{}, err
		}
		return convertProviderResponse(resp), nil
	}

	last := len(m.middleware) - 1
	for offset := range m.middleware {
		handler = m.middleware[last-offset](handler)
	}
	return handler
}
// buildProviderRequest converts the public request inputs into a
// provider.Request: the messages are translated, the sampling options are
// copied through (nil stays nil), stop sequences are set only when non-empty,
// and every tool in the configured toolbox is exported as a provider.ToolDef.
func buildProviderRequest(model string, messages []Message, cfg *requestConfig) provider.Request {
	req := provider.Request{
		Model:       model,
		Messages:    convertMessages(messages),
		Temperature: cfg.temperature,
		MaxTokens:   cfg.maxTokens,
		TopP:        cfg.topP,
	}

	if len(cfg.stop) > 0 {
		req.Stop = cfg.stop
	}

	if cfg.tools == nil {
		return req
	}
	for _, t := range cfg.tools.AllTools() {
		def := provider.ToolDef{
			Name:        t.Name,
			Description: t.Description,
			Schema:      t.Schema,
		}
		req.Tools = append(req.Tools, def)
	}
	return req
}
// convertMessages translates public Messages into provider.Messages one for
// one, copying role, text, tool-call ID, images and tool calls. The result
// always has the same length as msgs.
func convertMessages(msgs []Message) []provider.Message {
	converted := make([]provider.Message, 0, len(msgs))

	for idx := range msgs {
		src := msgs[idx]

		dst := provider.Message{
			Role:       string(src.Role),
			Content:    src.Content.Text,
			ToolCallID: src.ToolCallID,
		}

		for _, image := range src.Content.Images {
			dst.Images = append(dst.Images, provider.Image{
				URL:         image.URL,
				Base64:      image.Base64,
				ContentType: image.ContentType,
			})
		}

		for _, call := range src.ToolCalls {
			dst.ToolCalls = append(dst.ToolCalls, provider.ToolCall{
				ID:        call.ID,
				Name:      call.Name,
				Arguments: call.Arguments,
			})
		}

		converted = append(converted, dst)
	}

	return converted
}
// convertProviderResponse translates a provider.Response into the public
// Response. Besides text, tool calls and (when present) usage, it records the
// assistant message that represents this reply so it can be added to the
// conversation history; that message shares the response's tool-call slice.
func convertProviderResponse(resp provider.Response) Response {
	var out Response
	out.Text = resp.Text

	for _, call := range resp.ToolCalls {
		out.ToolCalls = append(out.ToolCalls, ToolCall{
			ID:        call.ID,
			Name:      call.Name,
			Arguments: call.Arguments,
		})
	}

	if u := resp.Usage; u != nil {
		out.Usage = &Usage{
			InputTokens:  u.InputTokens,
			OutputTokens: u.OutputTokens,
			TotalTokens:  u.TotalTokens,
		}
	}

	// Assistant message for conversation history.
	out.message = Message{
		Role:      RoleAssistant,
		Content:   Content{Text: resp.Text},
		ToolCalls: out.ToolCalls,
	}

	return out
}
// --- Provider constructors ---
// These are defined here and delegate to provider-specific packages.
// They are set up via init() in the provider packages, or defined directly.
// ClientOption configures a client. Options are functions that mutate a
// clientConfig.
type ClientOption func(*clientConfig)

// clientConfig collects the settings that ClientOptions can change.
type clientConfig struct {
	baseURL string // API base URL override; set by WithBaseURL
}
// WithBaseURL returns a ClientOption that sets the API base URL to url.
func WithBaseURL(url string) ClientOption {
	apply := func(cfg *clientConfig) {
		cfg.baseURL = url
	}
	return apply
}

264
v2/mcp.go Normal file
View File

@@ -0,0 +1,264 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"sync"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MCPTransport specifies how to connect to an MCP server.
type MCPTransport string

// Supported transports. connect treats any value other than MCPSSE and
// MCPHTTP as stdio.
const (
	MCPStdio MCPTransport = "stdio" // subprocess speaking MCP over stdin/stdout
	MCPSSE   MCPTransport = "sse"   // server-sent events endpoint
	MCPHTTP  MCPTransport = "http"  // streamable HTTP endpoint
)
// MCPServer represents a connection to an MCP server. It is created already
// connected by MCPStdioServer, MCPHTTPServer or MCPSSEServer.
type MCPServer struct {
	name      string       // friendly name used in error messages
	transport MCPTransport // selects which transport connect builds

	// stdio fields
	command string   // executable to launch
	args    []string // arguments passed to command
	env     []string // extra environment entries appended to os.Environ()

	// network fields
	url string // endpoint for the SSE and HTTP transports

	// internal
	client  *mcp.Client
	session *mcp.ClientSession   // nil while disconnected
	tools   map[string]*mcp.Tool // tools listed at connect time, keyed by name
	mu      sync.RWMutex         // guards session and tools
}
// MCPOption configures an MCP server. MCPHTTPServer and MCPSSEServer apply
// options before connecting.
type MCPOption func(*MCPServer)
// WithMCPEnv adds environment variables for the subprocess. Each entry has the
// form "KEY=value". Entries accumulate: using the option more than once keeps
// the variables from every use, in order.
func WithMCPEnv(env ...string) MCPOption {
	return func(s *MCPServer) { s.env = append(s.env, env...) }
}
// WithMCPName returns an option that replaces the server's friendly name,
// which appears in this package's error messages.
func WithMCPName(name string) MCPOption {
	setName := func(srv *MCPServer) {
		srv.name = name
	}
	return setName
}
// MCPStdioServer launches command with args and connects to it as an MCP
// server over stdio. The command string is also used as the server's name.
// It returns an error if the connection or the initial tool listing fails.
//
// Example:
//
//	server, err := llm.MCPStdioServer(ctx, "npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp")
func MCPStdioServer(ctx context.Context, command string, args ...string) (*MCPServer, error) {
	srv := new(MCPServer)
	srv.name = command
	srv.transport = MCPStdio
	srv.command = command
	srv.args = args

	err := srv.connect(ctx)
	if err != nil {
		return nil, err
	}
	return srv, nil
}
// MCPHTTPServer connects to the MCP server at url using the streamable HTTP
// transport. The URL is the default server name; opts are applied before
// connecting. It returns an error if the connection or the initial tool
// listing fails.
//
// Example:
//
//	server, err := llm.MCPHTTPServer(ctx, "https://mcp.example.com")
func MCPHTTPServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) {
	srv := new(MCPServer)
	srv.name = url
	srv.transport = MCPHTTP
	srv.url = url

	for i := range opts {
		opts[i](srv)
	}

	err := srv.connect(ctx)
	if err != nil {
		return nil, err
	}
	return srv, nil
}
// MCPSSEServer connects to the MCP server at url using the SSE transport. The
// URL is the default server name; opts are applied before connecting. It
// returns an error if the connection or the initial tool listing fails.
func MCPSSEServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) {
	srv := new(MCPServer)
	srv.name = url
	srv.transport = MCPSSE
	srv.url = url

	for i := range opts {
		opts[i](srv)
	}

	err := srv.connect(ctx)
	if err != nil {
		return nil, err
	}
	return srv, nil
}
// connect establishes the session and loads the server's tool list. It holds
// the write lock for its whole duration and is a no-op when a session already
// exists.
//
// The transport is chosen from s.transport; anything other than SSE or HTTP
// launches s.command as a subprocess whose environment is the current process
// environment plus s.env. If listing tools fails, the freshly opened session
// is closed and cleared before the error is returned.
func (s *MCPServer) connect(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.session != nil {
		return nil
	}

	impl := &mcp.Implementation{Name: "go-llm-v2", Version: "2.0.0"}
	s.client = mcp.NewClient(impl, nil)

	var transport mcp.Transport
	if s.transport == MCPSSE {
		transport = &mcp.SSEClientTransport{Endpoint: s.url}
	} else if s.transport == MCPHTTP {
		transport = &mcp.StreamableClientTransport{Endpoint: s.url}
	} else {
		// stdio
		cmd := exec.Command(s.command, s.args...)
		cmd.Env = append(os.Environ(), s.env...)
		transport = &mcp.CommandTransport{Command: cmd}
	}

	session, err := s.client.Connect(ctx, transport, nil)
	if err != nil {
		return fmt.Errorf("failed to connect to MCP server %s: %w", s.name, err)
	}
	s.session = session

	// Load tools.
	s.tools = make(map[string]*mcp.Tool)
	for tool, listErr := range session.Tools(ctx, nil) {
		if listErr == nil {
			s.tools[tool.Name] = tool
			continue
		}
		s.session.Close()
		s.session = nil
		return fmt.Errorf("failed to list tools from %s: %w", s.name, listErr)
	}

	return nil
}
// Close shuts down the session, if any, and forgets the cached tools. Calling
// it on a server that is not connected does nothing and returns nil; otherwise
// the error from closing the session is returned.
func (s *MCPServer) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	current := s.session
	if current == nil {
		return nil
	}

	closeErr := current.Close()
	s.session, s.tools = nil, nil
	return closeErr
}
// IsConnected reports whether the server currently has an open session.
func (s *MCPServer) IsConnected() bool {
	s.mu.RLock()
	connected := s.session != nil
	s.mu.RUnlock()
	return connected
}
// ListTools returns a Tool definition for every tool the server listed at
// connect time. The tools come from a map, so their order is unspecified; the
// result is nil when there are none.
func (s *MCPServer) ListTools() []Tool {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var out []Tool
	for name := range s.tools {
		out = append(out, s.toTool(s.tools[name]))
	}
	return out
}
// CallTool invokes the named tool on the server with the given arguments and
// returns its content flattened to a string by contentToString.
//
// It returns ErrNotConnected (wrapped with the server name) when there is no
// session, the session's error unchanged when the call fails, and an empty
// string when the result carries no content. The lock is held only while
// reading the session, not during the call itself.
func (s *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) {
	s.mu.RLock()
	active := s.session
	s.mu.RUnlock()

	if active == nil {
		return "", fmt.Errorf("%w: %s", ErrNotConnected, s.name)
	}

	params := &mcp.CallToolParams{
		Name:      name,
		Arguments: arguments,
	}
	result, err := active.CallTool(ctx, params)
	switch {
	case err != nil:
		return "", err
	case len(result.Content) == 0:
		return "", nil
	}
	return contentToString(result.Content), nil
}
// toTool converts an MCP tool description into a Tool bound to this server.
//
// The input schema is converted to map[string]any by round-tripping it through
// JSON. If the tool has no schema, or the round trip leaves the map nil, an
// empty object schema is substituted.
func (s *MCPServer) toTool(t *mcp.Tool) Tool {
	var schema map[string]any

	if t.InputSchema != nil {
		if raw, err := json.Marshal(t.InputSchema); err == nil {
			_ = json.Unmarshal(raw, &schema)
		}
	}

	if schema == nil {
		schema = map[string]any{
			"type":       "object",
			"properties": map[string]any{},
		}
	}

	var tool Tool
	tool.Name = t.Name
	tool.Description = t.Description
	tool.Schema = schema
	tool.isMCP = true
	tool.mcpServer = s
	return tool
}
// contentToString flattens MCP content items into one string. Text items
// contribute their text; every other item contributes its JSON encoding, and
// items that cannot be encoded are skipped. When exactly one piece results it
// is returned as is; otherwise the pieces are returned as a JSON array of
// strings.
func contentToString(content []mcp.Content) string {
	var pieces []string

	for _, item := range content {
		if text, isText := item.(*mcp.TextContent); isText {
			pieces = append(pieces, text.Text)
			continue
		}
		encoded, err := json.Marshal(item)
		if err != nil {
			continue
		}
		pieces = append(pieces, string(encoded))
	}

	if len(pieces) == 1 {
		return pieces[0]
	}

	joined, _ := json.Marshal(pieces)
	return string(joined)
}

73
v2/message.go Normal file
View File

@@ -0,0 +1,73 @@
package llm
// Role represents who authored a message.
type Role string

// Message roles.
const (
	RoleSystem    Role = "system"    // system prompt
	RoleUser      Role = "user"      // end-user input
	RoleAssistant Role = "assistant" // model output
	RoleTool      Role = "tool"      // result of a tool call
)
// Image represents an image attachment.
type Image struct {
	// Provide exactly one of URL or Base64.
	URL         string // HTTP(S) URL
	Base64      string // Raw base64-encoded data
	ContentType string // MIME type (e.g., "image/png"), required for Base64
}
// Content represents message content with optional text and images.
type Content struct {
	Text   string  // textual part of the message; may be empty
	Images []Image // image attachments; may be nil
}
// ToolCall represents a tool invocation requested by the assistant.
type ToolCall struct {
	ID        string // identifier echoed back in the matching tool result
	Name      string // name of the tool to invoke
	Arguments string // raw JSON
}
// Message represents a single message in a conversation. The same type is
// used for every role; ToolCallID and ToolCalls are only meaningful for the
// roles noted on each field.
type Message struct {
	Role    Role
	Content Content

	// ToolCallID is set when Role == RoleTool, identifying which tool call this responds to.
	ToolCallID string

	// ToolCalls is set when the assistant requests tool invocations.
	ToolCalls []ToolCall
}
// UserMessage builds a user-role message whose content is the given text.
func UserMessage(text string) Message {
	var msg Message
	msg.Role = RoleUser
	msg.Content.Text = text
	return msg
}
// UserMessageWithImages builds a user-role message carrying the given text
// together with the supplied image attachments.
func UserMessageWithImages(text string, images ...Image) Message {
	var msg Message
	msg.Role = RoleUser
	msg.Content.Text = text
	msg.Content.Images = images
	return msg
}
// SystemMessage builds a system-role message whose content is the given text.
func SystemMessage(text string) Message {
	var msg Message
	msg.Role = RoleSystem
	msg.Content.Text = text
	return msg
}
// AssistantMessage builds an assistant-role message whose content is the
// given text.
func AssistantMessage(text string) Message {
	var msg Message
	msg.Role = RoleAssistant
	msg.Content.Text = text
	return msg
}
// ToolResultMessage builds a tool-role message that reports result as the
// outcome of the tool call identified by toolCallID.
func ToolResultMessage(toolCallID string, result string) Message {
	var msg Message
	msg.Role = RoleTool
	msg.ToolCallID = toolCallID
	msg.Content.Text = result
	return msg
}

117
v2/middleware.go Normal file
View File

@@ -0,0 +1,117 @@
package llm
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// CompletionFunc is the signature for the completion call chain. It receives
// the model name, the conversation, and the resolved request configuration,
// and returns the model's response.
type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error)

// Middleware wraps a completion call. It receives the next handler in the chain
// and returns a new handler that can inspect/modify the request and response.
type Middleware func(next CompletionFunc) CompletionFunc
// WithLogging returns middleware that reports each completion through logger.
// An info record with the model and message count is written before the call.
// Afterwards either an error record (model, elapsed time, error) or an info
// record (model, elapsed time, text length, tool-call count) is written. The
// response and error are passed through untouched.
func WithLogging(logger *slog.Logger) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			logger.Info("llm request",
				"model", model,
				"message_count", len(messages),
			)

			began := time.Now()
			resp, err := next(ctx, model, messages, cfg)
			took := time.Since(began)

			if err != nil {
				logger.Error("llm error", "model", model, "elapsed", took, "error", err)
				return resp, err
			}

			logger.Info("llm response",
				"model", model,
				"elapsed", took,
				"text_len", len(resp.Text),
				"tool_calls", len(resp.ToolCalls),
			)
			return resp, err
		}
	}
}
// WithRetry returns middleware that re-issues a failed completion up to
// maxRetries additional times. Before retry number n (starting at 1) it waits
// for backoff(n), returning ctx.Err() immediately if the context ends during
// the wait. The first successful response is returned; if every attempt
// fails, the last error is returned wrapped with the retry count.
func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			var lastErr error

			attempt := 0
			for attempt <= maxRetries {
				// No delay before the very first attempt.
				if attempt != 0 {
					select {
					case <-ctx.Done():
						return Response{}, ctx.Err()
					case <-time.After(backoff(attempt)):
					}
				}

				resp, err := next(ctx, model, messages, cfg)
				if err == nil {
					return resp, nil
				}

				lastErr = err
				attempt++
			}

			return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr)
		}
	}
}
// WithTimeout returns middleware that runs the rest of the chain under a
// context that expires after d. The derived context is released when the call
// returns.
func WithTimeout(d time.Duration) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			limited, release := context.WithTimeout(ctx, d)
			defer release()

			return next(limited, model, messages, cfg)
		}
	}
}
// UsageTracker accumulates token usage statistics across calls.
// Add and Summary take mu; reading or writing the exported counters directly
// bypasses that lock.
type UsageTracker struct {
	mu            sync.Mutex // guards the counters below in Add and Summary
	TotalInput    int64      // sum of InputTokens over all recorded requests
	TotalOutput   int64      // sum of OutputTokens over all recorded requests
	TotalRequests int64      // number of requests recorded via Add
}
// Add folds the token counts of one request into the running totals and
// counts the request. A nil usage is ignored and does not count as a request.
func (ut *UsageTracker) Add(u *Usage) {
	if u == nil {
		return
	}

	in, out := int64(u.InputTokens), int64(u.OutputTokens)

	ut.mu.Lock()
	ut.TotalInput += in
	ut.TotalOutput += out
	ut.TotalRequests++
	ut.mu.Unlock()
}
// Summary returns the input-token total, output-token total and request count
// recorded so far, read together under the tracker's lock.
func (ut *UsageTracker) Summary() (input, output, requests int64) {
	ut.mu.Lock()
	input = ut.TotalInput
	output = ut.TotalOutput
	requests = ut.TotalRequests
	ut.mu.Unlock()

	return input, output, requests
}
// WithUsageTracking returns middleware that records the usage of every
// successful completion in tracker. Failed calls are not recorded. The
// response and error are passed through untouched.
func WithUsageTracking(tracker *UsageTracker) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			resp, err := next(ctx, model, messages, cfg)
			if err != nil {
				return resp, err
			}

			tracker.Add(resp.Usage)
			return resp, nil
		}
	}
}

323
v2/openai/openai.go Normal file
View File

@@ -0,0 +1,323 @@
// Package openai implements the go-llm v2 provider interface for OpenAI.
package openai
import (
"context"
"fmt"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/shared"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// Provider implements the provider.Provider interface for OpenAI.
// It stores only configuration; a fresh API client is built for every
// Complete and Stream call.
type Provider struct {
// apiKey is passed to the client via option.WithAPIKey.
apiKey string
// baseURL overrides the client's default endpoint when non-empty.
baseURL string
}
// New builds an OpenAI provider from an API key and an optional base URL.
// An empty baseURL leaves the client's default endpoint in place.
func New(apiKey string, baseURL string) *Provider {
	p := new(Provider)
	p.apiKey = apiKey
	p.baseURL = baseURL
	return p
}
// Complete sends req as a single, non-streaming chat completion and converts
// the reply into a provider.Response. API failures are wrapped and returned
// with a zero Response.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	opts := []option.RequestOption{option.WithAPIKey(p.apiKey)}
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}
	client := openai.NewClient(opts...)

	completion, err := client.Chat.Completions.New(ctx, p.buildRequest(req))
	if err != nil {
		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
	}
	return p.convertResponse(completion), nil
}
// Stream performs a streaming completion.
//
// Events are written to events as chunks arrive: a StreamEventText per text
// delta, a StreamEventToolStart when a tool call with an ID first appears, a
// StreamEventToolDelta per argument fragment, then one StreamEventToolEnd per
// accumulated tool call and a final StreamEventDone carrying the aggregated
// text and tool calls. Stream does not close events.
//
// Every send also watches ctx, so Stream returns instead of blocking forever
// when the consumer has stopped reading and cancelled the context. The
// underlying SSE stream is closed on every return path.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	var opts []option.RequestOption
	opts = append(opts, option.WithAPIKey(p.apiKey))
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}
	cl := openai.NewClient(opts...)
	oaiReq := p.buildRequest(req)

	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
	// Release the HTTP response body even when we bail out early.
	defer stream.Close()

	// send delivers ev unless ctx is cancelled first.
	send := func(ev provider.StreamEvent) error {
		select {
		case events <- ev:
			return nil
		case <-ctx.Done():
			return fmt.Errorf("openai stream error: %w", ctx.Err())
		}
	}

	var fullText strings.Builder
	var toolCalls []provider.ToolCall
	// Argument fragments are accumulated per tool-call index until the stream ends.
	toolCallArgs := map[int]*strings.Builder{}

	for stream.Next() {
		chunk := stream.Current()
		for _, choice := range chunk.Choices {
			// Text delta
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				if err := send(provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}); err != nil {
					return err
				}
			}

			// Tool call deltas
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)
				if tc.ID != "" {
					// New tool call starting: grow the slice so idx is addressable.
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}
					if err := send(provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}); err != nil {
						return err
					}
				}
				if tc.Function.Arguments != "" {
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					if err := send(provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}); err != nil {
						return err
					}
				}
			}
		}
	}
	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}

	// Finalize tool calls
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		if err := send(provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}); err != nil {
			return err
		}
	}

	// Send done event
	return send(provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
		},
	})
}
// buildRequest translates a provider.Request into OpenAI chat-completion
// parameters.
//
// Messages and tools are converted one-to-one. Temperature is omitted for
// models whose name starts with "o" or "gpt-5". MaxTokens is sent as
// MaxCompletionTokens. All stop sequences are forwarded: a single sequence
// uses the string arm of the union, several use the array arm (previously
// only the first sequence was sent and the rest were silently dropped).
func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
	oaiReq := openai.ChatCompletionNewParams{
		Model: req.Model,
	}

	for _, msg := range req.Messages {
		oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
	}

	for _, tool := range req.Tools {
		oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
			Type: "function",
			Function: shared.FunctionDefinitionParam{
				Name:        tool.Name,
				Description: openai.String(tool.Description),
				Parameters:  openai.FunctionParameters(tool.Schema),
			},
		})
	}

	if req.Temperature != nil {
		// o* and gpt-5* models don't support custom temperatures
		if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") {
			oaiReq.Temperature = openai.Float(*req.Temperature)
		}
	}
	if req.MaxTokens != nil {
		oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
	}
	if req.TopP != nil {
		oaiReq.TopP = openai.Float(*req.TopP)
	}

	switch len(req.Stop) {
	case 0:
		// No stop sequences requested.
	case 1:
		oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
	default:
		// NOTE(review): OfStringArray is the array arm of the stop union in
		// openai-go v1.x; confirm the field name against the pinned version.
		oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfStringArray: req.Stop}
	}
	return oaiReq
}
// convertMessage maps a provider.Message onto the matching OpenAI message
// parameter.
//
//   - "system" becomes a developer message when the model name contains a "-"
//     and starts with "o", otherwise a system message.
//   - "user" carries either plain text or, when images are attached, a list of
//     content parts (images first, then the text).
//   - "assistant" carries its text and any tool calls.
//   - "tool" carries the tool result together with the originating call ID.
//   - Any other role falls back to a text-only user message.
//
// Images with neither Base64 data nor a URL are skipped. Text is now kept for
// system and fallback messages even when the message also has images;
// previously the text was moved into the image part list, which those message
// kinds never send, so it was lost.
func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
	var imageParts []openai.ChatCompletionContentPartUnionParam
	for _, img := range msg.Images {
		var url string
		switch {
		case img.Base64 != "":
			url = "data:" + img.ContentType + ";base64," + img.Base64
		case img.URL != "":
			url = img.URL
		default:
			continue
		}
		imageParts = append(imageParts, openai.ChatCompletionContentPartUnionParam{
			OfImageURL: &openai.ChatCompletionContentPartImageParam{
				ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
					URL: url,
				},
			},
		})
	}

	// textContent stays unset (omitted from the request) for empty content.
	var textContent param.Opt[string]
	if msg.Content != "" {
		textContent = openai.String(msg.Content)
	}

	switch msg.Role {
	case "system":
		// Determine if this model uses developer messages instead of system.
		prefix, _, hasDash := strings.Cut(model, "-")
		if hasDash && strings.HasPrefix(prefix, "o") {
			return openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfSystem: &openai.ChatCompletionSystemMessageParam{
				Content: openai.ChatCompletionSystemMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}

	case "user":
		content := openai.ChatCompletionUserMessageParamContentUnion{OfString: textContent}
		if len(imageParts) > 0 {
			// With images the text travels as one more content part and the
			// plain-string arm is left unset.
			parts := imageParts
			if msg.Content != "" {
				parts = append(parts, openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: msg.Content,
					},
				})
			}
			content = openai.ChatCompletionUserMessageParamContentUnion{OfArrayOfContentParts: parts}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{Content: content},
		}

	case "assistant":
		as := &openai.ChatCompletionAssistantMessageParam{}
		if msg.Content != "" {
			as.Content.OfString = openai.String(msg.Content)
		}
		for _, tc := range msg.ToolCalls {
			as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
				ID: tc.ID,
				Function: openai.ChatCompletionMessageToolCallFunctionParam{
					Name:      tc.Name,
					Arguments: tc.Arguments,
				},
			})
		}
		return openai.ChatCompletionMessageParamUnion{OfAssistant: as}

	case "tool":
		return openai.ChatCompletionMessageParamUnion{
			OfTool: &openai.ChatCompletionToolMessageParam{
				ToolCallID: msg.ToolCallID,
				Content: openai.ChatCompletionToolMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				},
			},
		}
	}

	// Fallback to user message
	return openai.ChatCompletionMessageParamUnion{
		OfUser: &openai.ChatCompletionUserMessageParam{
			Content: openai.ChatCompletionUserMessageParamContentUnion{
				OfString: textContent,
			},
		},
	}
}
// convertResponse reduces an OpenAI chat completion to a provider.Response.
// Only the first choice is used. A nil response or one without choices yields
// the zero Response. Tool-call arguments are whitespace-trimmed, and Usage is
// attached only when the API reported a positive total token count.
func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
	if resp == nil || len(resp.Choices) == 0 {
		return provider.Response{}
	}

	first := resp.Choices[0].Message
	out := provider.Response{Text: first.Content}
	for _, call := range first.ToolCalls {
		out.ToolCalls = append(out.ToolCalls, provider.ToolCall{
			ID:        call.ID,
			Name:      call.Function.Name,
			Arguments: strings.TrimSpace(call.Function.Arguments),
		})
	}

	if usage := resp.Usage; usage.TotalTokens > 0 {
		out.Usage = &provider.Usage{
			InputTokens:  int(usage.PromptTokens),
			OutputTokens: int(usage.CompletionTokens),
			TotalTokens:  int(usage.TotalTokens),
		}
	}
	return out
}

92
v2/provider/provider.go Normal file
View File

@@ -0,0 +1,92 @@
// Package provider defines the interface that LLM backend implementations must satisfy.
package provider
import "context"
// Message is the provider-level message representation.
type Message struct {
// Role selects the message kind. The OpenAI provider recognizes "system",
// "user", "assistant" and "tool" and treats anything else as a user message.
Role string
// Content is the text of the message.
Content string
// Images are image attachments sent alongside Content.
Images []Image
// ToolCalls are the tool invocations carried by an assistant message.
ToolCalls []ToolCall
// ToolCallID names the tool call that a "tool" message is answering.
ToolCallID string
}
// Image represents an image attachment at the provider level.
// The OpenAI provider prefers Base64 when both Base64 and URL are set and
// skips images where both are empty.
type Image struct {
// URL is a link to the image.
URL string
// Base64 is the base64-encoded image data.
Base64 string
// ContentType is the media type used when building a data: URL from Base64.
ContentType string
}
// ToolCall represents a tool invocation requested by the model.
type ToolCall struct {
// ID identifies this call; tool results refer back to it via Message.ToolCallID.
ID string
// Name is the name of the tool to invoke.
Name string
// Arguments holds the call arguments as a JSON-encoded string.
Arguments string // raw JSON
}
// ToolDef defines a tool available to the model.
type ToolDef struct {
// Name is the identifier the model uses to call the tool.
Name string
// Description is sent to the model to explain what the tool does.
Description string
// Schema describes the tool's parameters as a JSON Schema document.
Schema map[string]any // JSON Schema
}
// Request is a completion request at the provider level.
// The pointer fields are optional: nil means "not set by the caller".
type Request struct {
// Model is the backend-specific model name.
Model string
// Messages is the conversation to send, in order.
Messages []Message
// Tools lists the tools offered to the model.
Tools []ToolDef
// Temperature is the sampling temperature, if set.
Temperature *float64
// MaxTokens caps the number of generated tokens, if set.
MaxTokens *int
// TopP is the nucleus sampling parameter, if set.
TopP *float64
// Stop holds stop sequences.
Stop []string
}
// Response is a completion response at the provider level.
type Response struct {
// Text is the assistant's text output.
Text string
// ToolCalls lists the tool invocations the model requested.
ToolCalls []ToolCall
// Usage reports token consumption; nil when the provider did not supply it.
Usage *Usage
}
// Usage captures token consumption.
type Usage struct {
// InputTokens is the number of prompt tokens.
InputTokens int
// OutputTokens is the number of completion tokens.
OutputTokens int
// TotalTokens is the total reported by the backend.
TotalTokens int
}
// StreamEventType identifies the kind of stream event.
// The constants below are numbered from zero in declaration order.
type StreamEventType int
const (
StreamEventText StreamEventType = iota // Text content delta
StreamEventToolStart // Tool call begins
StreamEventToolDelta // Tool call argument delta
StreamEventToolEnd // Tool call complete
StreamEventDone // Stream complete
StreamEventError // Error occurred
)
// StreamEvent represents a single event in a streaming response.
// Which fields are populated depends on Type.
type StreamEvent struct {
// Type says which kind of event this is.
Type StreamEventType
// Text is the text delta for StreamEventText.
Text string
// ToolCall carries tool-call data for the StreamEventTool* events.
ToolCall *ToolCall
// ToolIndex identifies which tool call a StreamEventTool* event refers to.
ToolIndex int
// Error is the failure reported by StreamEventError.
Error error
// Response is the aggregated result delivered with StreamEventDone.
Response *Response
}
// Provider is the interface that LLM backends implement.
type Provider interface {
// Complete performs a non-streaming completion.
Complete(ctx context.Context, req Request) (Response, error)
// Stream performs a streaming completion, sending events to the channel.
// NOTE(review): this comment used to say the provider MUST close the channel,
// but the caller in the llm package (newStreamReader) closes the channel
// itself once Stream returns, and the OpenAI provider never closes it. A
// provider that also closed it would cause a double-close panic. Confirm the
// intended contract against the remaining providers.
// The provider MUST send exactly one StreamEventDone as the last non-error event.
Stream(ctx context.Context, req Request, events chan<- StreamEvent) error
}

37
v2/request.go Normal file
View File

@@ -0,0 +1,37 @@
package llm
// RequestOption configures a single completion request.
// Options are functions that mutate the request's internal configuration;
// see WithTools, WithTemperature, WithMaxTokens, WithTopP and WithStop.
type RequestOption func(*requestConfig)
// requestConfig collects the per-request settings applied by RequestOption
// values. Nil pointer fields mean the option was not supplied.
type requestConfig struct {
// tools is the toolbox attached with WithTools.
tools *ToolBox
// temperature is set by WithTemperature.
temperature *float64
// maxTokens is set by WithMaxTokens.
maxTokens *int
// topP is set by WithTopP.
topP *float64
// stop is set by WithStop.
stop []string
}
// WithTools makes the tools in tb available to the request.
func WithTools(tb *ToolBox) RequestOption {
	return func(cfg *requestConfig) {
		cfg.tools = tb
	}
}
// WithTemperature requests sampling temperature t for this request.
func WithTemperature(t float64) RequestOption {
	value := &t
	return func(cfg *requestConfig) {
		cfg.temperature = value
	}
}
// WithMaxTokens caps the number of tokens the model may generate at n.
func WithMaxTokens(n int) RequestOption {
	limit := &n
	return func(cfg *requestConfig) {
		cfg.maxTokens = limit
	}
}
// WithTopP requests nucleus sampling with probability mass p.
func WithTopP(p float64) RequestOption {
	value := &p
	return func(cfg *requestConfig) {
		cfg.topP = value
	}
}
// WithStop supplies the stop sequences for this request. The slice is stored
// as passed, without copying.
func WithStop(sequences ...string) RequestOption {
	return func(cfg *requestConfig) {
		cfg.stop = sequences
	}
}

34
v2/response.go Normal file
View File

@@ -0,0 +1,34 @@
package llm
// Response represents the result of a completion request.
// Use Message to obtain the assistant message for the conversation history
// and HasToolCalls to check whether tools were requested.
type Response struct {
// Text is the assistant's text content. Empty if only tool calls.
Text string
// ToolCalls contains any tool invocations the assistant requested.
ToolCalls []ToolCall
// Usage contains token usage information (if available from provider).
Usage *Usage
// message is the full assistant message for this response.
// It is unexported and exposed read-only through the Message method.
message Message
}
// Message hands back the complete assistant Message behind this response so
// it can be appended to the conversation history.
func (r Response) Message() Message {
	assistant := r.message
	return assistant
}
// HasToolCalls reports whether the assistant asked for at least one tool call.
func (r Response) HasToolCalls() bool {
	return len(r.ToolCalls) != 0
}
// Usage captures token consumption.
type Usage struct {
// InputTokens is the number of tokens in the prompt.
InputTokens int
// OutputTokens is the number of tokens generated.
OutputTokens int
// TotalTokens is the total reported by the provider.
TotalTokens int
}

163
v2/stream.go Normal file
View File

@@ -0,0 +1,163 @@
package llm
import (
"context"
"fmt"
"io"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// StreamEventType identifies the kind of stream event.
// It is an alias of provider.StreamEventType, so values are interchangeable
// between the two packages.
type StreamEventType = provider.StreamEventType
// Re-exports of the provider-level event kinds.
const (
StreamEventText = provider.StreamEventText
StreamEventToolStart = provider.StreamEventToolStart
StreamEventToolDelta = provider.StreamEventToolDelta
StreamEventToolEnd = provider.StreamEventToolEnd
StreamEventDone = provider.StreamEventDone
StreamEventError = provider.StreamEventError
)
// StreamEvent represents a single event in a streaming response.
// It is the public counterpart of provider.StreamEvent; which fields are
// populated depends on Type.
type StreamEvent struct {
// Type says which kind of event this is.
Type StreamEventType
// Text is set for StreamEventText — the text delta.
Text string
// ToolCall is set for StreamEventToolStart/ToolDelta/ToolEnd.
ToolCall *ToolCall
// ToolIndex identifies which tool call is being updated.
ToolIndex int
// Error is set for StreamEventError.
Error error
// Response is set for StreamEventDone — the complete, aggregated response.
Response *Response
}
// StreamReader reads streaming events from an LLM response.
// Must be closed when done.
type StreamReader struct {
// events delivers converted events; it is closed when the stream ends.
events <-chan StreamEvent
// cancel cancels the context the stream runs under; Close calls it.
cancel context.CancelFunc
// done is set once Next has seen StreamEventDone or a closed channel.
done bool
}
// newStreamReader starts a provider stream and returns a reader over its events.
//
// Two goroutines are started. One runs p.Stream and closes the provider
// channel when it returns; if Stream fails, the error is delivered as a final
// StreamEventError. The other converts provider events into public events and
// closes the public channel when the provider channel is exhausted.
//
// When the derived context is cancelled (for example via StreamReader.Close),
// the forwarder stops publishing but keeps draining the provider channel, so a
// provider blocked on a send can always make progress and exit. Because the
// provider channel is always consumed, the error event is sent with a blocking
// send rather than being dropped when the buffer happens to be full.
//
// The returned error is always nil.
func newStreamReader(ctx context.Context, p provider.Provider, req provider.Request) (*StreamReader, error) {
	ctx, cancel := context.WithCancel(ctx)
	providerEvents := make(chan provider.StreamEvent, 32)
	publicEvents := make(chan StreamEvent, 32)

	go func() {
		defer close(publicEvents)
		for pev := range providerEvents {
			select {
			case publicEvents <- convertStreamEvent(pev):
			case <-ctx.Done():
				// Nobody is listening any more; discard the rest so the
				// producer is never stuck on a full channel.
				for range providerEvents {
				}
				return
			}
		}
	}()

	go func() {
		defer close(providerEvents)
		if err := p.Stream(ctx, req, providerEvents); err != nil {
			providerEvents <- provider.StreamEvent{Type: provider.StreamEventError, Error: err}
		}
	}()

	return &StreamReader{
		events: publicEvents,
		cancel: cancel,
	}, nil
}
// convertStreamEvent translates a provider-level stream event into the public
// StreamEvent type. The tool call and the response, when present, are copied
// into fresh public values rather than shared with the provider event.
func convertStreamEvent(pev provider.StreamEvent) StreamEvent {
	out := StreamEvent{
		Type:      pev.Type,
		Text:      pev.Text,
		ToolIndex: pev.ToolIndex,
		Error:     pev.Error,
	}
	if src := pev.ToolCall; src != nil {
		out.ToolCall = &ToolCall{
			ID:        src.ID,
			Name:      src.Name,
			Arguments: src.Arguments,
		}
	}
	if src := pev.Response; src != nil {
		converted := convertProviderResponse(*src)
		out.Response = &converted
	}
	return out
}
// Next blocks for the next stream event.
//
// It returns io.EOF once the stream has finished: either the event channel
// was closed, or a StreamEventDone was already handed out by an earlier call.
// A StreamEventError is returned together with its error.
func (sr *StreamReader) Next() (StreamEvent, error) {
	if sr.done {
		return StreamEvent{}, io.EOF
	}

	ev, open := <-sr.events
	switch {
	case !open:
		sr.done = true
		return StreamEvent{}, io.EOF
	case ev.Type == StreamEventError:
		return ev, ev.Error
	case ev.Type == StreamEventDone:
		sr.done = true
	}
	return ev, nil
}
// Close cancels the context the stream runs under, releasing its resources.
// It always returns nil.
func (sr *StreamReader) Close() error {
	stop := sr.cancel
	stop()
	return nil
}
// Collect drains the stream and returns the aggregated Response carried by
// the StreamEventDone event. Any stream error is returned as-is; a stream
// that ends without a final response is reported as an error.
func (sr *StreamReader) Collect() (Response, error) {
	var final *Response
	for {
		ev, err := sr.Next()
		if err != nil {
			if err == io.EOF {
				break
			}
			return Response{}, err
		}
		if ev.Response != nil && ev.Type == StreamEventDone {
			final = ev.Response
		}
	}
	if final != nil {
		return *final, nil
	}
	return Response{}, fmt.Errorf("stream completed without final response")
}
// Text drains the stream via Collect and returns only the response text.
func (sr *StreamReader) Text() (string, error) {
	collected, err := sr.Collect()
	if err == nil {
		return collected.Text, nil
	}
	return "", err
}

193
v2/tool.go Normal file
View File

@@ -0,0 +1,193 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"reflect"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
)
// Tool defines a tool that the LLM can invoke.
// Tools with a local handler are created with Define or DefineSimple; the
// unexported fields below are not populated by a plain struct literal.
type Tool struct {
// Name is the tool's unique identifier.
Name string
// Description tells the LLM what this tool does.
Description string
// Schema is the JSON Schema for the tool's parameters.
Schema map[string]any
// fn holds the implementation function (set via Define or DefineSimple).
fn reflect.Value
pTyp reflect.Type // nil for parameterless tools
// isMCP indicates this tool is provided by an MCP server.
isMCP bool
// mcpServer is the server Execute forwards the call to when isMCP is true.
mcpServer *MCPServer
}
// Define builds a Tool whose handler takes a typed parameter struct.
// T must be a struct; its fields describe the tool's parameters and the JSON
// Schema is generated from it.
//
// Recognized struct tags:
//   - `json:"name"` — parameter name
//   - `description:"..."` — parameter description
//   - `enum:"a,b,c"` — enum constraint
//
// Pointer fields are optional; non-pointer fields are required.
//
// Example:
//
//	type WeatherParams struct {
//		City string `json:"city" description:"The city to query"`
//		Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"`
//	}
//
//	llm.Define[WeatherParams]("get_weather", "Get weather for a city",
//		func(ctx context.Context, p WeatherParams) (string, error) {
//			return fmt.Sprintf("72F in %s", p.City), nil
//		},
//	)
func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool {
	var params T

	tool := Tool{Name: name, Description: description}
	tool.Schema = schema.FromStruct(params)
	tool.fn = reflect.ValueOf(fn)
	tool.pTyp = reflect.TypeOf(params)
	return tool
}
// DefineSimple builds a Tool that takes no parameters. Its schema is an
// object with an empty property set.
//
// Example:
//
//	llm.DefineSimple("get_time", "Get the current time",
//		func(ctx context.Context) (string, error) {
//			return time.Now().Format(time.RFC3339), nil
//		},
//	)
func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool {
	emptyObject := map[string]any{"type": "object", "properties": map[string]any{}}

	tool := Tool{Name: name, Description: description}
	tool.Schema = emptyObject
	tool.fn = reflect.ValueOf(fn)
	return tool
}
// Execute runs the tool with the given JSON arguments string.
//
// MCP tools decode the arguments into a map and forward the call to their
// server. Local tools decode the arguments into the handler's parameter
// struct (parameterless tools ignore them) and invoke the handler. An empty
// string or "{}" is treated as "no arguments" and skips decoding.
//
// An error is returned when the arguments are not valid JSON for the tool,
// when the handler itself fails, or when the Tool has nothing to run: an MCP
// tool without a server, or a Tool that was built as a struct literal instead
// of through Define/DefineSimple. The latter two cases used to panic.
func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) {
	hasArgs := argsJSON != "" && argsJSON != "{}"

	if t.isMCP {
		if t.mcpServer == nil {
			return "", fmt.Errorf("tool %q: MCP server is not set", t.Name)
		}
		var args map[string]any
		if hasArgs {
			if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
				return "", fmt.Errorf("invalid MCP tool arguments: %w", err)
			}
		}
		return t.mcpServer.CallTool(ctx, t.Name, args)
	}

	if !t.fn.IsValid() {
		// Calling a zero reflect.Value panics; report it instead.
		return "", fmt.Errorf("tool %q has no handler; create tools with Define or DefineSimple", t.Name)
	}

	callArgs := []reflect.Value{reflect.ValueOf(ctx)}
	if t.pTyp != nil {
		// Typed tool: unmarshal JSON into a fresh parameter value.
		p := reflect.New(t.pTyp)
		if hasArgs {
			if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil {
				return "", fmt.Errorf("invalid tool arguments: %w", err)
			}
		}
		callArgs = append(callArgs, p.Elem())
	}

	// Handlers have the shape func(...) (string, error).
	out := t.fn.Call(callArgs)
	if !out[1].IsNil() {
		return "", out[1].Interface().(error)
	}
	return out[0].String(), nil
}
// ToolBox is a collection of tools available for use by an LLM.
// Tools are keyed by name, so adding a tool whose name is already present
// replaces the earlier entry.
type ToolBox struct {
// tools maps tool name to tool, for both local and MCP-provided tools.
tools map[string]Tool
// mcpServers records every server registered through AddMCP.
mcpServers []*MCPServer
}
// NewToolBox returns a ToolBox holding the given tools, keyed by name.
// When two tools share a name, the later one wins.
func NewToolBox(tools ...Tool) *ToolBox {
	registry := make(map[string]Tool)
	for i := range tools {
		registry[tools[i].Name] = tools[i]
	}
	return &ToolBox{tools: registry}
}
// Add registers tools in the toolbox, replacing any existing tool of the same
// name, and returns the toolbox so calls can be chained. The internal map is
// created on first use.
func (tb *ToolBox) Add(tools ...Tool) *ToolBox {
	if tb.tools == nil {
		tb.tools = make(map[string]Tool)
	}
	for i := range tools {
		tb.tools[tools[i].Name] = tools[i]
	}
	return tb
}
// AddMCP records server and registers every tool it lists, replacing tools
// with the same name. The server must be connected. The toolbox is returned
// for chaining.
func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox {
	if tb.tools == nil {
		tb.tools = make(map[string]Tool)
	}
	tb.mcpServers = append(tb.mcpServers, server)

	listed := server.ListTools()
	for _, remote := range listed {
		tb.tools[remote.Name] = remote
	}
	return tb
}
// AllTools returns every registered tool (local and MCP) as a slice. The
// order follows map iteration and is therefore unspecified. A nil toolbox
// yields nil.
func (tb *ToolBox) AllTools() []Tool {
	if tb == nil {
		return nil
	}
	out := make([]Tool, len(tb.tools))
	i := 0
	for _, tool := range tb.tools {
		out[i] = tool
		i++
	}
	return out
}
// Execute looks up the tool named in call and runs it with the call's
// arguments. It returns ErrNoToolsConfigured for a nil toolbox and an error
// wrapping ErrToolNotFound when no tool has that name.
func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) {
	if tb == nil {
		return "", ErrNoToolsConfigured
	}
	if tool, found := tb.tools[call.Name]; found {
		return tool.Execute(ctx, call.Arguments)
	}
	return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name)
}
// ExecuteAll runs the calls in order and returns one tool-result message per
// call. A failing call does not stop the batch: its message carries
// "Error: " followed by the error text. The returned error is always nil.
func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) {
	var msgs []Message
	for i := range calls {
		output, err := tb.Execute(ctx, calls[i])
		if err != nil {
			output = "Error: " + err.Error()
		}
		msgs = append(msgs, ToolResultMessage(calls[i].ID, output))
	}
	return msgs, nil
}

139
v2/tool_test.go Normal file
View File

@@ -0,0 +1,139 @@
package llm
import (
"context"
"encoding/json"
"testing"
)
// calcParams is the parameter struct for the calculator tool used in
// TestDefine. The tags exercise the json, description and enum tags.
type calcParams struct {
A float64 `json:"a" description:"First number"`
B float64 `json:"b" description:"Second number"`
Op string `json:"op" description:"Operation" enum:"add,sub,mul,div"`
}
func TestDefine(t *testing.T) {
tool := Define[calcParams]("calc", "Calculator",
func(ctx context.Context, p calcParams) (string, error) {
var result float64
switch p.Op {
case "add":
result = p.A + p.B
case "sub":
result = p.A - p.B
case "mul":
result = p.A * p.B
case "div":
result = p.A / p.B
}
b, err := json.Marshal(result)
return string(b), err
},
)
if tool.Name != "calc" {
t.Errorf("expected name 'calc', got %q", tool.Name)
}
if tool.Description != "Calculator" {
t.Errorf("expected description 'Calculator', got %q", tool.Description)
}
if tool.Schema["type"] != "object" {
t.Errorf("expected schema type=object, got %v", tool.Schema["type"])
}
// Test execution
result, err := tool.Execute(context.Background(), `{"a": 10, "b": 3, "op": "add"}`)
if err != nil {
t.Fatalf("execute failed: %v", err)
}
if result != "13" {
t.Errorf("expected '13', got %q", result)
}
}
// TestDefineSimple checks that a parameterless tool runs with empty arguments
// and returns the handler's string.
func TestDefineSimple(t *testing.T) {
	greet := func(ctx context.Context) (string, error) {
		return "Hello, world!", nil
	}
	tool := DefineSimple("hello", "Say hello", greet)

	got, err := tool.Execute(context.Background(), "")
	if err != nil {
		t.Fatalf("execute failed: %v", err)
	}
	if got != "Hello, world!" {
		t.Errorf("expected 'Hello, world!', got %q", got)
	}
}
// TestToolBox checks registration, lookup by name, and the error returned for
// an unknown tool.
func TestToolBox(t *testing.T) {
	first := DefineSimple("tool1", "Tool 1", func(ctx context.Context) (string, error) {
		return "result1", nil
	})
	second := DefineSimple("tool2", "Tool 2", func(ctx context.Context) (string, error) {
		return "result2", nil
	})
	tb := NewToolBox(first, second)

	if n := len(tb.AllTools()); n != 2 {
		t.Errorf("expected 2 tools, got %d", n)
	}

	ctx := context.Background()
	got, err := tb.Execute(ctx, ToolCall{ID: "1", Name: "tool1"})
	if err != nil {
		t.Fatalf("execute failed: %v", err)
	}
	if got != "result1" {
		t.Errorf("expected 'result1', got %q", got)
	}

	// Test not found
	if _, err = tb.Execute(ctx, ToolCall{ID: "x", Name: "nonexistent"}); err == nil {
		t.Error("expected error for nonexistent tool")
	}
}
// TestToolBoxExecuteAll checks that ExecuteAll returns one tool-result message
// per call, carrying the matching call ID and the handler output.
func TestToolBoxExecuteAll(t *testing.T) {
	constant := func(out string) func(context.Context) (string, error) {
		return func(ctx context.Context) (string, error) {
			return out, nil
		}
	}
	tb := NewToolBox(
		DefineSimple("t1", "T1", constant("r1")),
		DefineSimple("t2", "T2", constant("r2")),
	)

	calls := []ToolCall{
		{ID: "c1", Name: "t1"},
		{ID: "c2", Name: "t2"},
	}
	msgs, err := tb.ExecuteAll(context.Background(), calls)
	if err != nil {
		t.Fatalf("execute all failed: %v", err)
	}
	if len(msgs) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(msgs))
	}

	first := msgs[0]
	if first.Role != RoleTool {
		t.Errorf("expected role=tool, got %v", first.Role)
	}
	if first.ToolCallID != "c1" {
		t.Errorf("expected toolCallID=c1, got %v", first.ToolCallID)
	}
	if first.Content.Text != "r1" {
		t.Errorf("expected content=r1, got %v", first.Content.Text)
	}
}
// jsonMarshal encodes result as JSON and returns it as a string. The receiver
// is not used.
func (calcParams) jsonMarshal(result float64) (string, error) {
	encoded, err := json.Marshal(result)
	return string(encoded), err
}

59
v2/tools/browser.go Normal file
View File

@@ -0,0 +1,59 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// BrowserParams defines parameters for the browser tool.
type BrowserParams struct {
// URL is the address requested with an HTTP GET.
URL string `json:"url" description:"The URL to fetch and extract text from"`
}
// Browser returns a simple page-fetching tool named "browser".
// The tool issues a GET for the requested URL and returns a JSON object with
// the URL, status code, Content-Type header and up to 1MB of the raw body.
//
// For a full headless browser, consider using an MCP server like Playwright MCP.
//
// Example:
//
//	tools := llm.NewToolBox(tools.Browser())
func Browser() llm.Tool {
	fetch := func(ctx context.Context, p BrowserParams) (string, error) {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.URL, nil)
		if err != nil {
			return "", fmt.Errorf("creating request: %w", err)
		}
		req.Header.Set("User-Agent", "go-llm/2.0 (Web Fetcher)")

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return "", fmt.Errorf("fetching URL: %w", err)
		}
		defer resp.Body.Close()

		// Limit to 1MB
		body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
		if err != nil {
			return "", fmt.Errorf("reading body: %w", err)
		}

		payload := map[string]any{
			"url":          p.URL,
			"status":       resp.StatusCode,
			"content_type": resp.Header.Get("Content-Type"),
			"body":         string(body),
		}
		encoded, _ := json.MarshalIndent(payload, "", " ")
		return string(encoded), nil
	}
	return llm.Define[BrowserParams]("browser", "Fetch a web page and return its text content", fetch)
}

101
v2/tools/exec.go Normal file
View File

@@ -0,0 +1,101 @@
package tools
import (
"context"
"fmt"
"os/exec"
"runtime"
"strings"
"time"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// ExecParams defines parameters for the exec tool.
type ExecParams struct {
// Command is passed as a single string to the platform shell
// ("sh -c" or, on Windows, "cmd /C").
Command string `json:"command" description:"The shell command to execute"`
}
// ExecOption configures the exec tool.
// See WithAllowedCommands, WithWorkDir and WithExecTimeout.
type ExecOption func(*execConfig)
// execConfig holds the settings applied by ExecOption values.
type execConfig struct {
// allowedCommands is the allow-list of command names; empty allows everything.
allowedCommands []string
// workDir is the working directory for the command; empty leaves it unset.
workDir string
// timeout bounds each command execution; Exec defaults it to 30 seconds.
timeout time.Duration
}
// WithAllowedCommands limits execution to commands whose first word is in
// cmds. An empty list allows every command.
func WithAllowedCommands(cmds []string) ExecOption {
	return func(cfg *execConfig) {
		cfg.allowedCommands = cmds
	}
}
// WithWorkDir makes commands run with dir as their working directory.
func WithWorkDir(dir string) ExecOption {
	return func(cfg *execConfig) {
		cfg.workDir = dir
	}
}
// WithExecTimeout caps how long a single command may run at d.
func WithExecTimeout(d time.Duration) ExecOption {
	return func(cfg *execConfig) {
		cfg.timeout = d
	}
}
// Exec creates a shell command execution tool named "exec".
//
// The command string is run through the platform shell ("sh -c", or
// "cmd /C" on Windows) in the configured working directory and under the
// configured timeout (30 seconds by default). The combined stdout and stderr
// is returned; a failing command is reported as text ("Error: ...") in the
// result rather than as a Go error, so the model can see the output.
//
// When an allow-list is configured, the first word of the command must be on
// the list AND the command must not contain shell metacharacters. Because the
// string is interpreted by a shell, checking only the first word would let
// input such as "ls; rm -rf x" or "ls $(...)" run arbitrary programs. The
// check is deliberately strict: it also rejects metacharacters inside quotes.
//
// Example:
//
//	tools := llm.NewToolBox(
//		tools.Exec(tools.WithAllowedCommands([]string{"ls", "cat", "grep"})),
//	)
func Exec(opts ...ExecOption) llm.Tool {
	cfg := &execConfig{
		timeout: 30 * time.Second,
	}
	for _, opt := range opts {
		opt(cfg)
	}

	// Characters that let a shell chain, substitute or redirect commands.
	metachars := ";&|<>$`(){}\n\r"
	if runtime.GOOS == "windows" {
		metachars += "^%"
	}

	return llm.Define[ExecParams]("exec", "Execute a shell command and return its output",
		func(ctx context.Context, p ExecParams) (string, error) {
			// Check allowed commands
			if len(cfg.allowedCommands) > 0 {
				parts := strings.Fields(p.Command)
				if len(parts) == 0 {
					return "", fmt.Errorf("empty command")
				}
				allowed := false
				for _, cmd := range cfg.allowedCommands {
					if parts[0] == cmd {
						allowed = true
						break
					}
				}
				if !allowed {
					return "", fmt.Errorf("command %q is not in the allowed list", parts[0])
				}
				if strings.ContainsAny(p.Command, metachars) {
					return "", fmt.Errorf("command contains shell metacharacters, which are not permitted when an allowed list is configured")
				}
			}

			ctx, cancel := context.WithTimeout(ctx, cfg.timeout)
			defer cancel()

			var cmd *exec.Cmd
			if runtime.GOOS == "windows" {
				cmd = exec.CommandContext(ctx, "cmd", "/C", p.Command)
			} else {
				cmd = exec.CommandContext(ctx, "sh", "-c", p.Command)
			}
			if cfg.workDir != "" {
				cmd.Dir = cfg.workDir
			}

			output, err := cmd.CombinedOutput()
			if err != nil {
				// Surface the failure to the model as tool output, not as a Go error.
				return fmt.Sprintf("Error: %s\nOutput: %s", err.Error(), string(output)), nil
			}
			return string(output), nil
		},
	)
}

75
v2/tools/http.go Normal file
View File

@@ -0,0 +1,75 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// HTTPParams defines parameters for the HTTP request tool.
type HTTPParams struct {
// Method is the HTTP method used for the request.
Method string `json:"method" description:"HTTP method" enum:"GET,POST,PUT,DELETE,PATCH,HEAD"`
// URL is the request target.
URL string `json:"url" description:"Request URL"`
// Headers are set on the request one by one; optional.
Headers map[string]string `json:"headers,omitempty" description:"Request headers"`
// Body is the request body; nil sends no body.
Body *string `json:"body,omitempty" description:"Request body"`
}
// HTTP returns a general-purpose HTTP request tool named "http_request".
// The tool sends the request described by HTTPParams and returns a JSON
// object with the status code, status text, response headers (first value of
// each) and up to 1MB of the body.
//
// Example:
//
//	tools := llm.NewToolBox(tools.HTTP())
func HTTP() llm.Tool {
	do := func(ctx context.Context, p HTTPParams) (string, error) {
		var payload io.Reader
		if p.Body != nil {
			payload = bytes.NewBufferString(*p.Body)
		}

		req, err := http.NewRequestWithContext(ctx, p.Method, p.URL, payload)
		if err != nil {
			return "", fmt.Errorf("creating request: %w", err)
		}
		for name, value := range p.Headers {
			req.Header.Set(name, value)
		}

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return "", fmt.Errorf("request failed: %w", err)
		}
		defer resp.Body.Close()

		// Limit to 1MB
		body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
		if err != nil {
			return "", fmt.Errorf("reading response: %w", err)
		}

		// Keep only the first value of each response header.
		flat := map[string]string{}
		for name, values := range resp.Header {
			if len(values) == 0 {
				continue
			}
			flat[name] = values[0]
		}

		encoded, _ := json.MarshalIndent(map[string]any{
			"status":      resp.StatusCode,
			"status_text": resp.Status,
			"headers":     flat,
			"body":        string(body),
		}, "", " ")
		return string(encoded), nil
	}
	return llm.Define[HTTPParams]("http_request", "Make an HTTP request and return the response", do)
}

81
v2/tools/readfile.go Normal file
View File

@@ -0,0 +1,81 @@
package tools
import (
"bufio"
"context"
"fmt"
"os"
"strings"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// ReadFileParams defines parameters for the read file tool.
// When both Start and End are nil the whole file is returned; otherwise only
// the selected lines are returned, each prefixed with its line number.
type ReadFileParams struct {
// Path is the file to open.
Path string `json:"path" description:"File path to read"`
// Start is the first line to return; defaults to 1 when only End is given.
Start *int `json:"start,omitempty" description:"Starting line number (1-based, inclusive)"`
// End is the last line to return; when omitted, reading continues to the end of the file.
End *int `json:"end,omitempty" description:"Ending line number (1-based, inclusive)"`
}
// ReadFile creates a file reading tool named "read_file".
//
// Without a line range the whole file is returned, provided it is no larger
// than 1MB; larger files are rejected with a hint to use start/end. With a
// range, the selected lines are returned joined by newlines, each prefixed
// with "<line number>: ".
//
// In range mode the scanner's line buffer is raised from bufio's 64KiB
// default to 1MB, matching the whole-file limit, so files with very long
// lines (for example minified JSON) no longer fail with "token too long".
//
// Example:
//
//	tools := llm.NewToolBox(tools.ReadFile())
func ReadFile() llm.Tool {
	const maxBytes = 1 << 20

	return llm.Define[ReadFileParams]("read_file", "Read the contents of a file",
		func(ctx context.Context, p ReadFileParams) (string, error) {
			f, err := os.Open(p.Path)
			if err != nil {
				return "", fmt.Errorf("opening file: %w", err)
			}
			defer f.Close()

			// If no line range specified, read the whole file (limited to 1MB)
			if p.Start == nil && p.End == nil {
				info, err := f.Stat()
				if err != nil {
					return "", fmt.Errorf("stat file: %w", err)
				}
				if info.Size() > maxBytes {
					return "", fmt.Errorf("file too large (%d bytes), use start/end to read a range", info.Size())
				}
				data, err := os.ReadFile(p.Path)
				if err != nil {
					return "", fmt.Errorf("reading file: %w", err)
				}
				return string(data), nil
			}

			// Read specific line range; end == -1 means "until EOF".
			start := 1
			end := -1
			if p.Start != nil {
				start = *p.Start
			}
			if p.End != nil {
				end = *p.End
			}

			var lines []string
			scanner := bufio.NewScanner(f)
			// Allow single lines up to the same 1MB bound used for whole files.
			scanner.Buffer(make([]byte, 0, 64*1024), maxBytes)
			lineNum := 0
			for scanner.Scan() {
				lineNum++
				if lineNum < start {
					continue
				}
				if end > 0 && lineNum > end {
					break
				}
				lines = append(lines, fmt.Sprintf("%d: %s", lineNum, scanner.Text()))
			}
			if err := scanner.Err(); err != nil {
				return "", fmt.Errorf("scanning file: %w", err)
			}
			return strings.Join(lines, "\n"), nil
		},
	)
}

101
v2/tools/websearch.go Normal file
View File

@@ -0,0 +1,101 @@
// Package tools provides ready-to-use tool implementations for common agent patterns.
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// WebSearchParams defines parameters for the web search tool.
type WebSearchParams struct {
// Query is the search text; it is URL-escaped before being sent.
Query string `json:"query" description:"The search query"`
// Count is the number of results requested. Nil or non-positive values fall
// back to 5, and values above 20 are clamped to 20.
Count *int `json:"count,omitempty" description:"Number of results to return (default 5, max 20)"`
}
// WebSearch creates a web search tool named "web_search" backed by the Brave
// Search API.
//
// The tool returns a JSON array of {title, url, description} objects taken
// from the "web.results" section of the API response. An empty result set is
// returned as "[]" (it used to be "null"). If the response cannot be decoded,
// the raw body is returned instead. A non-200 status is reported as an error
// that includes the body.
//
// The response body is read through a 1MB limit, matching the other HTTP
// tools in this package.
//
// Get a free API key at https://brave.com/search/api/
//
// Example:
//
//	tools := llm.NewToolBox(tools.WebSearch("your-brave-api-key"))
func WebSearch(apiKey string) llm.Tool {
	return llm.Define[WebSearchParams]("web_search", "Search the web for information using Brave Search",
		func(ctx context.Context, p WebSearchParams) (string, error) {
			count := 5
			if p.Count != nil && *p.Count > 0 {
				count = *p.Count
				if count > 20 {
					count = 20
				}
			}

			u := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
				url.QueryEscape(p.Query), count)
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
			if err != nil {
				return "", fmt.Errorf("creating request: %w", err)
			}
			req.Header.Set("Accept", "application/json")
			req.Header.Set("X-Subscription-Token", apiKey)

			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				return "", fmt.Errorf("search request failed: %w", err)
			}
			defer resp.Body.Close()

			// Limit to 1MB so a misbehaving endpoint cannot exhaust memory.
			body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
			if err != nil {
				return "", fmt.Errorf("reading response: %w", err)
			}
			if resp.StatusCode != http.StatusOK {
				return "", fmt.Errorf("search API returned %d: %s", resp.StatusCode, string(body))
			}

			// Parse and simplify the response
			type result struct {
				Title       string `json:"title"`
				URL         string `json:"url"`
				Description string `json:"description"`
			}
			var payload struct {
				Web struct {
					Results []result `json:"results"`
				} `json:"web"`
			}
			if err := json.Unmarshal(body, &payload); err != nil {
				// Not the shape we expect; hand the raw body to the model.
				return string(body), nil
			}

			results := payload.Web.Results
			if results == nil {
				// Encode "no results" as [] rather than null.
				results = []result{}
			}
			out, err := json.MarshalIndent(results, "", " ")
			if err != nil {
				return "", fmt.Errorf("encoding results: %w", err)
			}
			return string(out), nil
		},
	)
}

31
v2/tools/writefile.go Normal file
View File

@@ -0,0 +1,31 @@
package tools
import (
"context"
"fmt"
"os"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// WriteFileParams defines the parameters accepted by the write file tool.
type WriteFileParams struct {
// Path is the destination file path, passed to os.WriteFile as given.
Path string `json:"path" description:"File path to write"`
// Content is the complete text written to the file.
Content string `json:"content" description:"Content to write to the file"`
}
// WriteFile builds a tool that stores text in a file on the local filesystem.
//
// The target is written with os.WriteFile using mode 0644, so a missing file
// is created and an existing one is replaced. On success the tool reports the
// number of bytes written together with the path.
//
// Example:
//
//	tools := llm.NewToolBox(tools.WriteFile())
func WriteFile() llm.Tool {
	write := func(_ context.Context, p WriteFileParams) (string, error) {
		data := []byte(p.Content)
		err := os.WriteFile(p.Path, data, 0644)
		if err != nil {
			return "", fmt.Errorf("writing file: %w", err)
		}
		msg := fmt.Sprintf("Successfully wrote %d bytes to %s", len(data), p.Path)
		return msg, nil
	}
	return llm.Define[WriteFileParams]("write_file", "Write content to a file (creates or overwrites)", write)
}