From a4cb4baab517a6a8a3b79f51103d39b78c09bdae Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Sat, 7 Feb 2026 20:00:08 -0500 Subject: [PATCH] Add go-llm v2: redesigned API for simpler LLM abstraction v2 is a new Go module (v2/) with a dramatically simpler API: - Unified Message type (no more Input marker interface) - Define[T] for ergonomic tool creation with standard context.Context - Chat session with automatic tool-call loop (agent loop) - Streaming via pull-based StreamReader - MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer) - Middleware support (logging, retry, timeout, usage tracking) - Decoupled JSON Schema (map[string]any, no provider coupling) - Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP - Providers: OpenAI, Anthropic, Google (all with streaming) Co-Authored-By: Claude Opus 4.6 --- v2/CLAUDE.md | 31 +++ v2/anthropic/anthropic.go | 273 +++++++++++++++++++++++++ v2/chat.go | 153 ++++++++++++++ v2/constructors.go | 48 +++++ v2/errors.go | 17 ++ v2/go.mod | 39 ++++ v2/go.sum | 159 +++++++++++++++ v2/google/google.go | 322 +++++++++++++++++++++++++++++ v2/internal/imageutil/compress.go | 105 ++++++++++ v2/internal/schema/schema.go | 188 +++++++++++++++++ v2/internal/schema/schema_test.go | 181 +++++++++++++++++ v2/llm.go | 199 ++++++++++++++++++ v2/mcp.go | 264 ++++++++++++++++++++++++ v2/message.go | 73 +++++++ v2/middleware.go | 117 +++++++++++ v2/openai/openai.go | 323 ++++++++++++++++++++++++++++++ v2/provider/provider.go | 92 +++++++++ v2/request.go | 37 ++++ v2/response.go | 34 ++++ v2/stream.go | 163 +++++++++++++++ v2/tool.go | 193 ++++++++++++++++++ v2/tool_test.go | 139 +++++++++++++ v2/tools/browser.go | 59 ++++++ v2/tools/exec.go | 101 ++++++++++ v2/tools/http.go | 75 +++++++ v2/tools/readfile.go | 81 ++++++++ v2/tools/websearch.go | 101 ++++++++++ v2/tools/writefile.go | 31 +++ 28 files changed, 3598 insertions(+) create mode 100644 v2/CLAUDE.md create mode 100644 v2/anthropic/anthropic.go create mode 100644 v2/chat.go create mode 100644 v2/constructors.go create mode 100644 v2/errors.go create mode 100644 v2/go.mod create mode 100644 v2/go.sum create mode 100644 v2/google/google.go create mode 100644 v2/internal/imageutil/compress.go create mode 100644 v2/internal/schema/schema.go create mode 100644 v2/internal/schema/schema_test.go create mode 100644 v2/llm.go create mode 100644 v2/mcp.go create mode 100644 v2/message.go create mode 100644 v2/middleware.go create mode 100644 v2/openai/openai.go create mode 100644 v2/provider/provider.go create mode 100644 v2/request.go create mode 100644 v2/response.go create mode 100644 v2/stream.go create mode 100644 v2/tool.go create mode 100644 v2/tool_test.go create mode 100644 v2/tools/browser.go create mode 100644 v2/tools/exec.go create mode 100644 v2/tools/http.go create mode 100644 v2/tools/readfile.go create mode 100644 v2/tools/websearch.go create mode 100644 v2/tools/writefile.go diff --git a/v2/CLAUDE.md b/v2/CLAUDE.md new file mode 100644 index 0000000..ff2d203 --- /dev/null +++ b/v2/CLAUDE.md @@ -0,0 +1,31 @@ +# CLAUDE.md for go-llm v2 + +## Build and Test Commands +- Build project: `cd v2 && go build ./...` +- Run all tests: `cd v2 && go test ./...` +- Run specific test: `cd v2 && go test -v -run ./...` +- Tidy dependencies: `cd v2 && go mod tidy` +- Vet: `cd v2 && go vet ./...` + +## Code Style Guidelines +- **Indentation**: Standard Go tabs +- **Naming**: `camelCase` for unexported, `PascalCase` for exported +- **Error Handling**: Always check and handle errors immediately. Wrap with `fmt.Errorf("%w: ...", err)` +- **Imports**: Standard library first, then third-party, then internal packages + +## Package Structure +- Root package `llm` — public API (Client, Model, Chat, ToolBox, Message types) +- `provider/` — Provider interface that backends implement +- `openai/`, `anthropic/`, `google/` — Provider implementations +- `tools/` — Ready-to-use sample tools (WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP) +- `internal/schema/` — JSON Schema generation from Go structs +- `internal/imageutil/` — Image compression utilities + +## Key Design Decisions +1. Unified `Message` type instead of marker interfaces +2. `map[string]any` JSON Schema (no provider coupling) +3. Tool functions return `(string, error)`, use standard `context.Context` +4. `Chat.Send()` auto-loops tool calls; `Chat.SendRaw()` for manual control +5. MCP one-call connect: `MCPStdioServer(ctx, cmd, args...)` +6. Streaming via pull-based `StreamReader.Next()` +7. Middleware for logging, retry, timeout, usage tracking diff --git a/v2/anthropic/anthropic.go b/v2/anthropic/anthropic.go new file mode 100644 index 0000000..beaf64b --- /dev/null +++ b/v2/anthropic/anthropic.go @@ -0,0 +1,273 @@ +// Package anthropic implements the go-llm v2 provider interface for Anthropic. +package anthropic + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/imageutil" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" + + anth "github.com/liushuangls/go-anthropic/v2" +) + +// Provider implements the provider.Provider interface for Anthropic. +type Provider struct { + apiKey string +} + +// New creates a new Anthropic provider. +func New(apiKey string) *Provider { + return &Provider{apiKey: apiKey} +} + +// Complete performs a non-streaming completion. +func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + cl := anth.NewClient(p.apiKey) + + anthReq := p.buildRequest(req) + + resp, err := cl.CreateMessages(ctx, anthReq) + if err != nil { + return provider.Response{}, fmt.Errorf("anthropic completion error: %w", err) + } + + return p.convertResponse(resp), nil +} + +// Stream performs a streaming completion. +func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + cl := anth.NewClient(p.apiKey) + + anthReq := p.buildRequest(req) + + resp, err := cl.CreateMessagesStream(ctx, anth.MessagesStreamRequest{ + MessagesRequest: anthReq, + OnContentBlockDelta: func(data anth.MessagesEventContentBlockDeltaData) { + if data.Delta.Type == "text_delta" && data.Delta.Text != nil { + events <- provider.StreamEvent{ + Type: provider.StreamEventText, + Text: *data.Delta.Text, + } + } + }, + }) + + if err != nil { + return fmt.Errorf("anthropic stream error: %w", err) + } + + result := p.convertResponse(resp) + events <- provider.StreamEvent{ + Type: provider.StreamEventDone, + Response: &result, + } + + return nil +} + +func (p *Provider) buildRequest(req provider.Request) anth.MessagesRequest { + anthReq := anth.MessagesRequest{ + Model: anth.Model(req.Model), + MaxTokens: 4096, + } + + if req.MaxTokens != nil { + anthReq.MaxTokens = *req.MaxTokens + } + + var msgs []anth.Message + + for _, msg := range req.Messages { + if msg.Role == "system" { + if len(anthReq.System) > 0 { + anthReq.System += "\n" + } + anthReq.System += msg.Content + continue + } + + if msg.Role == "tool" { + // Tool results in Anthropic format - use the helper + toolUseID := msg.ToolCallID + content := msg.Content + isError := false + msgs = append(msgs, anth.Message{ + Role: anth.RoleUser, + Content: []anth.MessageContent{ + { + Type: anth.MessagesContentTypeToolResult, + MessageContentToolResult: &anth.MessageContentToolResult{ + ToolUseID: &toolUseID, + Content: []anth.MessageContent{ + { + Type: anth.MessagesContentTypeText, + Text: &content, + }, + }, + IsError: &isError, + }, + }, + }, + }) + continue + } + + role := anth.RoleUser + if msg.Role == "assistant" { + role = anth.RoleAssistant + } + + m := anth.Message{ + Role: role, + Content: []anth.MessageContent{}, + } + + if msg.Content != "" { + m.Content = append(m.Content, anth.MessageContent{ + Type: anth.MessagesContentTypeText, + Text: &msg.Content, + }) + } + + // Handle tool calls in assistant messages + for _, tc := range msg.ToolCalls { + var input json.RawMessage + if tc.Arguments != "" { + input = json.RawMessage(tc.Arguments) + } else { + input = json.RawMessage("{}") + } + m.Content = append(m.Content, anth.MessageContent{ + Type: anth.MessagesContentTypeToolUse, + MessageContentToolUse: &anth.MessageContentToolUse{ + ID: tc.ID, + Name: tc.Name, + Input: input, + }, + }) + } + + // Handle images + for _, img := range msg.Images { + if role == anth.RoleAssistant { + role = anth.RoleUser + m.Role = anth.RoleUser + } + + if img.Base64 != "" { + b64 := img.Base64 + contentType := img.ContentType + + // Compress if > 5MiB + raw, err := base64.StdEncoding.DecodeString(b64) + if err == nil && len(raw) >= 5242880 { + compressed, mime, cerr := imageutil.CompressImage(b64, 5*1024*1024) + if cerr == nil { + b64 = compressed + contentType = mime + } + } + + m.Content = append(m.Content, anth.NewImageMessageContent( + anth.NewMessageContentSource( + anth.MessagesContentSourceTypeBase64, + contentType, + b64, + ))) + } else if img.URL != "" { + // Download and convert to base64 (Anthropic doesn't support URLs directly) + resp, err := http.Get(img.URL) + if err != nil { + continue + } + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + continue + } + + contentType := resp.Header.Get("Content-Type") + b64 := base64.StdEncoding.EncodeToString(data) + + m.Content = append(m.Content, anth.NewImageMessageContent( + anth.NewMessageContentSource( + anth.MessagesContentSourceTypeBase64, + contentType, + b64, + ))) + } + } + + // Merge consecutive same-role messages (Anthropic requires alternating) + if len(msgs) > 0 && msgs[len(msgs)-1].Role == role { + msgs[len(msgs)-1].Content = append(msgs[len(msgs)-1].Content, m.Content...) + } else { + msgs = append(msgs, m) + } + } + + for _, tool := range req.Tools { + anthReq.Tools = append(anthReq.Tools, anth.ToolDefinition{ + Name: tool.Name, + Description: tool.Description, + InputSchema: tool.Schema, + }) + } + + anthReq.Messages = msgs + + if req.Temperature != nil { + f := float32(*req.Temperature) + anthReq.Temperature = &f + } + + if req.TopP != nil { + f := float32(*req.TopP) + anthReq.TopP = &f + } + + if len(req.Stop) > 0 { + anthReq.StopSequences = req.Stop + } + + return anthReq +} + +func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response { + var res provider.Response + var textParts []string + + for _, block := range resp.Content { + switch block.Type { + case anth.MessagesContentTypeText: + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + case anth.MessagesContentTypeToolUse: + if block.MessageContentToolUse != nil { + args, _ := json.Marshal(block.MessageContentToolUse.Input) + res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ + ID: block.MessageContentToolUse.ID, + Name: block.MessageContentToolUse.Name, + Arguments: string(args), + }) + } + } + } + + res.Text = strings.Join(textParts, "") + + res.Usage = &provider.Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + + return res +} diff --git a/v2/chat.go b/v2/chat.go new file mode 100644 index 0000000..47d2f64 --- /dev/null +++ b/v2/chat.go @@ -0,0 +1,153 @@ +package llm + +import ( + "context" + "fmt" +) + +// Chat manages a multi-turn conversation with automatic history tracking +// and optional automatic tool-call execution. +type Chat struct { + model *Model + messages []Message + tools *ToolBox + opts []RequestOption +} + +// NewChat creates a new conversation with the given model. +func NewChat(model *Model, opts ...RequestOption) *Chat { + return &Chat{ + model: model, + opts: opts, + } +} + +// SetSystem sets or replaces the system message. +func (c *Chat) SetSystem(text string) { + filtered := make([]Message, 0, len(c.messages)+1) + for _, m := range c.messages { + if m.Role != RoleSystem { + filtered = append(filtered, m) + } + } + c.messages = append([]Message{SystemMessage(text)}, filtered...) +} + +// SetTools configures the tools available for this chat. +func (c *Chat) SetTools(tb *ToolBox) { + c.tools = tb +} + +// Send sends a user message and returns the assistant's text response. +// If the model calls tools, they are executed automatically and the loop +// continues until the model produces a text response (the "agent loop"). +func (c *Chat) Send(ctx context.Context, text string) (string, error) { + return c.SendMessage(ctx, UserMessage(text)) +} + +// SendWithImages sends a user message with images attached. +func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) { + return c.SendMessage(ctx, UserMessageWithImages(text, images...)) +} + +// SendMessage sends an arbitrary message and returns the final text response. +// Handles the full tool-call loop automatically. +func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) { + c.messages = append(c.messages, msg) + + opts := c.buildOpts() + + for { + resp, err := c.model.Complete(ctx, c.messages, opts...) + if err != nil { + return "", fmt.Errorf("completion failed: %w", err) + } + + c.messages = append(c.messages, resp.Message()) + + if !resp.HasToolCalls() { + return resp.Text, nil + } + + if c.tools == nil { + return "", ErrNoToolsConfigured + } + + toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls) + if err != nil { + return "", fmt.Errorf("tool execution failed: %w", err) + } + + c.messages = append(c.messages, toolResults...) + } +} + +// SendRaw sends a message and returns the raw Response without automatic tool execution. +// Useful when you want to handle tool calls manually. +func (c *Chat) SendRaw(ctx context.Context, msg Message) (Response, error) { + c.messages = append(c.messages, msg) + + opts := c.buildOpts() + + resp, err := c.model.Complete(ctx, c.messages, opts...) + if err != nil { + return Response{}, err + } + + c.messages = append(c.messages, resp.Message()) + return resp, nil +} + +// SendStream sends a user message and returns a StreamReader for streaming responses. +func (c *Chat) SendStream(ctx context.Context, text string) (*StreamReader, error) { + c.messages = append(c.messages, UserMessage(text)) + + opts := c.buildOpts() + + cfg := &requestConfig{} + for _, opt := range opts { + opt(cfg) + } + + req := buildProviderRequest(c.model.model, c.messages, cfg) + return newStreamReader(ctx, c.model.provider, req) +} + +// AddToolResults manually adds tool results to the conversation. +// Use with SendRaw when handling tool calls manually. +func (c *Chat) AddToolResults(results ...Message) { + c.messages = append(c.messages, results...) +} + +// Messages returns the current conversation history (read-only copy). +func (c *Chat) Messages() []Message { + cp := make([]Message, len(c.messages)) + copy(cp, c.messages) + return cp +} + +// Reset clears the conversation history. +func (c *Chat) Reset() { + c.messages = nil +} + +// Fork creates a copy of this chat with identical history, for branching conversations. +func (c *Chat) Fork() *Chat { + c2 := &Chat{ + model: c.model, + messages: make([]Message, len(c.messages)), + tools: c.tools, + opts: c.opts, + } + copy(c2.messages, c.messages) + return c2 +} + +func (c *Chat) buildOpts() []RequestOption { + opts := make([]RequestOption, len(c.opts)) + copy(opts, c.opts) + if c.tools != nil { + opts = append(opts, WithTools(c.tools)) + } + return opts +} diff --git a/v2/constructors.go b/v2/constructors.go new file mode 100644 index 0000000..b5cb3b0 --- /dev/null +++ b/v2/constructors.go @@ -0,0 +1,48 @@ +package llm + +import ( + anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic" + googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google" + openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai" +) + +// OpenAI creates an OpenAI client. +// +// Example: +// +// model := llm.OpenAI("sk-...").Model("gpt-4o") +func OpenAI(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return newClient(openaiProvider.New(apiKey, cfg.baseURL)) +} + +// Anthropic creates an Anthropic client. +// +// Example: +// +// model := llm.Anthropic("sk-ant-...").Model("claude-sonnet-4-20250514") +func Anthropic(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + _ = cfg // Anthropic doesn't support custom base URL in the SDK + return newClient(anthProvider.New(apiKey)) +} + +// Google creates a Google (Gemini) client. +// +// Example: +// +// model := llm.Google("...").Model("gemini-2.0-flash") +func Google(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + _ = cfg // Google doesn't support custom base URL in the SDK + return newClient(googleProvider.New(apiKey)) +} diff --git a/v2/errors.go b/v2/errors.go new file mode 100644 index 0000000..5082608 --- /dev/null +++ b/v2/errors.go @@ -0,0 +1,17 @@ +package llm + +import "errors" + +var ( + // ErrNoToolsConfigured is returned when the model requests tool calls but no tools are available. + ErrNoToolsConfigured = errors.New("model requested tool calls but no tools configured") + + // ErrToolNotFound is returned when a requested tool is not in the toolbox. + ErrToolNotFound = errors.New("tool not found") + + // ErrNotConnected is returned when trying to use an MCP server that isn't connected. + ErrNotConnected = errors.New("MCP server not connected") + + // ErrStreamClosed is returned when trying to read from a closed stream. + ErrStreamClosed = errors.New("stream closed") +) diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..5b9c582 --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,39 @@ +module gitea.stevedudenhoeffer.com/steve/go-llm/v2 + +go 1.24.0 + +toolchain go1.24.2 + +require ( + github.com/liushuangls/go-anthropic/v2 v2.17.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 + github.com/openai/openai-go v1.12.0 + golang.org/x/image v0.35.0 + google.golang.org/genai v1.45.0 +) + +require ( + cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go/auth v0.9.3 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect + github.com/google/s2a-go v0.1.8 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/gorilla/websocket v1.5.3 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.opencensus.io v0.24.0 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/text v0.33.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/grpc v1.66.2 // indirect + google.golang.org/protobuf v1.34.2 // indirect +) diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..2dc7df9 --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,159 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= +cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U= +cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c= +github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= +github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= +golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genai v1.45.0 h1:s80ZpS42XW0zu/ogiOtenCio17nJ7reEFJjoCftukpA= +google.golang.org/genai v1.45.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo= +google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/v2/google/google.go b/v2/google/google.go new file mode 100644 index 0000000..180f721 --- /dev/null +++ b/v2/google/google.go @@ -0,0 +1,322 @@ +// Package google implements the go-llm v2 provider interface for Google (Gemini). +package google + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" + + "google.golang.org/genai" +) + +// Provider implements the provider.Provider interface for Google Gemini. +type Provider struct { + apiKey string +} + +// New creates a new Google provider. +func New(apiKey string) *Provider { + return &Provider{apiKey: apiKey} +} + +// Complete performs a non-streaming completion. +func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + cl, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: p.apiKey, + Backend: genai.BackendGeminiAPI, + }) + if err != nil { + return provider.Response{}, fmt.Errorf("google client error: %w", err) + } + + contents, cfg := p.buildRequest(req) + + resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg) + if err != nil { + return provider.Response{}, fmt.Errorf("google completion error: %w", err) + } + + return p.convertResponse(resp) +} + +// Stream performs a streaming completion. +func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + cl, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: p.apiKey, + Backend: genai.BackendGeminiAPI, + }) + if err != nil { + return fmt.Errorf("google client error: %w", err) + } + + contents, cfg := p.buildRequest(req) + + var fullText strings.Builder + var toolCalls []provider.ToolCall + + for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) { + if err != nil { + return fmt.Errorf("google stream error: %w", err) + } + + for _, c := range resp.Candidates { + if c.Content == nil { + continue + } + for _, part := range c.Content.Parts { + if part.Text != "" { + fullText.WriteString(part.Text) + events <- provider.StreamEvent{ + Type: provider.StreamEventText, + Text: part.Text, + } + } + if part.FunctionCall != nil { + args, _ := json.Marshal(part.FunctionCall.Args) + tc := provider.ToolCall{ + ID: part.FunctionCall.Name, + Name: part.FunctionCall.Name, + Arguments: string(args), + } + toolCalls = append(toolCalls, tc) + events <- provider.StreamEvent{ + Type: provider.StreamEventToolStart, + ToolCall: &tc, + ToolIndex: len(toolCalls) - 1, + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolEnd, + ToolCall: &tc, + ToolIndex: len(toolCalls) - 1, + } + } + } + } + } + + events <- provider.StreamEvent{ + Type: provider.StreamEventDone, + Response: &provider.Response{ + Text: fullText.String(), + ToolCalls: toolCalls, + }, + } + + return nil +} + +func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) { + var contents []*genai.Content + cfg := &genai.GenerateContentConfig{} + + for _, tool := range req.Tools { + cfg.Tools = append(cfg.Tools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{ + { + Name: tool.Name, + Description: tool.Description, + Parameters: schemaToGenai(tool.Schema), + }, + }, + }) + } + + if req.Temperature != nil { + f := float32(*req.Temperature) + cfg.Temperature = &f + } + + if req.MaxTokens != nil { + cfg.MaxOutputTokens = int32(*req.MaxTokens) + } + + if req.TopP != nil { + f := float32(*req.TopP) + cfg.TopP = &f + } + + if len(req.Stop) > 0 { + cfg.StopSequences = req.Stop + } + + for _, msg := range req.Messages { + var role genai.Role + switch msg.Role { + case "system": + cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser) + continue + case "assistant": + role = genai.RoleModel + case "tool": + // Tool results go as function responses (Genai uses RoleUser for function responses) + contents = append(contents, &genai.Content{ + Role: genai.RoleUser, + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + Name: msg.ToolCallID, + Response: map[string]any{ + "result": msg.Content, + }, + }, + }, + }, + }) + continue + default: + role = genai.RoleUser + } + + var parts []*genai.Part + + if msg.Content != "" { + parts = append(parts, genai.NewPartFromText(msg.Content)) + } + + // Handle tool calls in assistant messages + for _, tc := range msg.ToolCalls { + var args map[string]any + if tc.Arguments != "" { + _ = json.Unmarshal([]byte(tc.Arguments), &args) + } + parts = append(parts, &genai.Part{ + FunctionCall: &genai.FunctionCall{ + Name: tc.Name, + Args: args, + }, + }) + } + + for _, img := range msg.Images { + if img.URL != "" { + // Gemini doesn't support URLs directly; download + resp, err := http.Get(img.URL) + if err != nil { + continue + } + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + continue + } + + mimeType := http.DetectContentType(data) + parts = append(parts, genai.NewPartFromBytes(data, mimeType)) + } else if img.Base64 != "" { + data, err := base64.StdEncoding.DecodeString(img.Base64) + if err != nil { + continue + } + parts = append(parts, genai.NewPartFromBytes(data, img.ContentType)) + } + } + + contents = append(contents, genai.NewContentFromParts(parts, role)) + } + + return contents, cfg +} + +func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) { + var res provider.Response + + for _, c := range resp.Candidates { + if c.Content == nil { + continue + } + for _, part := range c.Content.Parts { + if part.Text != "" { + res.Text += part.Text + } + if part.FunctionCall != nil { + args, _ := json.Marshal(part.FunctionCall.Args) + res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ + ID: part.FunctionCall.Name, + Name: part.FunctionCall.Name, + Arguments: string(args), + }) + } + } + } + + if resp.UsageMetadata != nil { + res.Usage = &provider.Usage{ + InputTokens: int(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(resp.UsageMetadata.TotalTokenCount), + } + } + + return res, nil +} + +// schemaToGenai converts a JSON Schema map to a genai.Schema. +func schemaToGenai(s map[string]any) *genai.Schema { + if s == nil { + return nil + } + + schema := &genai.Schema{} + + if t, ok := s["type"].(string); ok { + switch t { + case "object": + schema.Type = genai.TypeObject + case "array": + schema.Type = genai.TypeArray + case "string": + schema.Type = genai.TypeString + case "integer": + schema.Type = genai.TypeInteger + case "number": + schema.Type = genai.TypeNumber + case "boolean": + schema.Type = genai.TypeBoolean + } + } + + if desc, ok := s["description"].(string); ok { + schema.Description = desc + } + + if props, ok := s["properties"].(map[string]any); ok { + schema.Properties = make(map[string]*genai.Schema) + for k, v := range props { + if vm, ok := v.(map[string]any); ok { + schema.Properties[k] = schemaToGenai(vm) + } + } + } + + if req, ok := s["required"].([]string); ok { + schema.Required = req + } else if req, ok := s["required"].([]any); ok { + for _, r := range req { + if rs, ok := r.(string); ok { + schema.Required = append(schema.Required, rs) + } + } + } + + if items, ok := s["items"].(map[string]any); ok { + schema.Items = schemaToGenai(items) + } + + if enums, ok := s["enum"].([]string); ok { + schema.Enum = enums + } else if enums, ok := s["enum"].([]any); ok { + for _, e := range enums { + if es, ok := e.(string); ok { + schema.Enum = append(schema.Enum, es) + } + } + } + + return schema +} diff --git a/v2/internal/imageutil/compress.go b/v2/internal/imageutil/compress.go new file mode 100644 index 0000000..908b16c --- /dev/null +++ b/v2/internal/imageutil/compress.go @@ -0,0 +1,105 @@ +// Package imageutil provides image compression utilities. +package imageutil + +import ( + "bytes" + "encoding/base64" + "fmt" + "image" + "image/gif" + "image/jpeg" + _ "image/png" // register PNG decoder + "net/http" + + "golang.org/x/image/draw" +) + +// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns +// a base-64-encoded version that is at most maxLength bytes, along with the MIME type. +func CompressImage(b64 string, maxLength int) (string, string, error) { + raw, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return "", "", fmt.Errorf("base64 decode: %w", err) + } + + mime := http.DetectContentType(raw) + if len(raw) <= maxLength { + return b64, mime, nil + } + + switch mime { + case "image/gif": + return compressGIF(raw, maxLength) + default: + return compressRaster(raw, maxLength) + } +} + +func compressRaster(src []byte, maxLength int) (string, string, error) { + img, _, err := image.Decode(bytes.NewReader(src)) + if err != nil { + return "", "", fmt.Errorf("decode raster: %w", err) + } + + quality := 95 + for { + var buf bytes.Buffer + if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil { + return "", "", fmt.Errorf("jpeg encode: %w", err) + } + if buf.Len() <= maxLength { + return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/jpeg", nil + } + + if quality > 20 { + quality -= 5 + continue + } + + b := img.Bounds() + if b.Dx() < 100 || b.Dy() < 100 { + return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0) + } + dst := image.NewRGBA(image.Rect(0, 0, int(float64(b.Dx())*0.8), int(float64(b.Dy())*0.8))) + draw.ApproxBiLinear.Scale(dst, dst.Bounds(), img, b, draw.Over, nil) + img = dst + quality = 95 + } +} + +func compressGIF(src []byte, maxLength int) (string, string, error) { + g, err := gif.DecodeAll(bytes.NewReader(src)) + if err != nil { + return "", "", fmt.Errorf("gif decode: %w", err) + } + + for { + var buf bytes.Buffer + if err := gif.EncodeAll(&buf, g); err != nil { + return "", "", fmt.Errorf("gif encode: %w", err) + } + if buf.Len() <= maxLength { + return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil + } + + w, h := g.Config.Width, g.Config.Height + if w < 100 || h < 100 { + return "", "", fmt.Errorf("cannot compress animated GIF below %.02fMiB", float64(maxLength)/1048576.0) + } + + nw, nh := int(float64(w)*0.8), int(float64(h)*0.8) + for i, frm := range g.Image { + rgba := image.NewRGBA(frm.Bounds()) + draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src) + + dst := image.NewRGBA(image.Rect(0, 0, nw, nh)) + draw.ApproxBiLinear.Scale(dst, dst.Bounds(), rgba, rgba.Bounds(), draw.Over, nil) + + paletted := image.NewPaletted(dst.Bounds(), nil) + draw.FloydSteinberg.Draw(paletted, paletted.Bounds(), dst, dst.Bounds().Min) + + g.Image[i] = paletted + } + g.Config.Width, g.Config.Height = nw, nh + } +} diff --git a/v2/internal/schema/schema.go b/v2/internal/schema/schema.go new file mode 100644 index 0000000..63af420 --- /dev/null +++ b/v2/internal/schema/schema.go @@ -0,0 +1,188 @@ +// Package schema provides JSON Schema generation from Go structs. +// It produces standard JSON Schema as map[string]any, with no provider-specific types. +package schema + +import ( + "reflect" + "strings" +) + +// FromStruct generates a JSON Schema object from a Go struct. +// Struct tags used: +// - `json:"name"` — sets the property name (standard Go JSON convention) +// - `description:"..."` — sets the property description +// - `enum:"a,b,c"` — restricts string values to the given set +// +// Pointer fields are treated as optional; non-pointer fields are required. +// Anonymous (embedded) struct fields are flattened into the parent. +func FromStruct(v any) map[string]any { + t := reflect.TypeOf(v) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + panic("schema.FromStruct expects a struct or pointer to struct") + } + return objectSchema(t) +} + +func objectSchema(t reflect.Type) map[string]any { + properties := map[string]any{} + var required []string + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Flatten anonymous (embedded) structs + if field.Anonymous { + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if ft.Kind() == reflect.Struct { + embedded := objectSchema(ft) + if props, ok := embedded["properties"].(map[string]any); ok { + for k, v := range props { + properties[k] = v + } + } + if req, ok := embedded["required"].([]string); ok { + required = append(required, req...) + } + } + continue + } + + name := fieldName(field) + isRequired := true + ft := field.Type + + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + isRequired = false + } + + prop := fieldSchema(field, ft) + properties[name] = prop + + if isRequired { + required = append(required, name) + } + } + + result := map[string]any{ + "type": "object", + "properties": properties, + } + if len(required) > 0 { + result["required"] = required + } + return result +} + +func fieldSchema(field reflect.StructField, ft reflect.Type) map[string]any { + prop := map[string]any{} + + // Check for enum tag first + if enumTag, ok := field.Tag.Lookup("enum"); ok { + vals := parseEnum(enumTag) + prop["type"] = "string" + prop["enum"] = vals + if desc, ok := field.Tag.Lookup("description"); ok { + prop["description"] = desc + } + return prop + } + + switch ft.Kind() { + case reflect.String: + prop["type"] = "string" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + prop["type"] = "integer" + case reflect.Float32, reflect.Float64: + prop["type"] = "number" + case reflect.Bool: + prop["type"] = "boolean" + case reflect.Struct: + return objectSchema(ft) + case reflect.Slice: + prop["type"] = "array" + elemType := ft.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + prop["items"] = typeSchema(elemType) + case reflect.Map: + prop["type"] = "object" + if ft.Key().Kind() == reflect.String { + valType := ft.Elem() + if valType.Kind() == reflect.Ptr { + valType = valType.Elem() + } + prop["additionalProperties"] = typeSchema(valType) + } + default: + prop["type"] = "string" // fallback + } + + if desc, ok := field.Tag.Lookup("description"); ok { + prop["description"] = desc + } + + return prop +} + +func typeSchema(t reflect.Type) map[string]any { + switch t.Kind() { + case reflect.String: + return map[string]any{"type": "string"} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return map[string]any{"type": "integer"} + case reflect.Float32, reflect.Float64: + return map[string]any{"type": "number"} + case reflect.Bool: + return map[string]any{"type": "boolean"} + case reflect.Struct: + return objectSchema(t) + case reflect.Slice: + elemType := t.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + return map[string]any{ + "type": "array", + "items": typeSchema(elemType), + } + default: + return map[string]any{"type": "string"} + } +} + +func fieldName(f reflect.StructField) string { + if tag, ok := f.Tag.Lookup("json"); ok { + parts := strings.SplitN(tag, ",", 2) + if parts[0] != "" && parts[0] != "-" { + return parts[0] + } + } + return f.Name +} + +func parseEnum(tag string) []string { + parts := strings.Split(tag, ",") + var vals []string + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + vals = append(vals, p) + } + } + return vals +} diff --git a/v2/internal/schema/schema_test.go b/v2/internal/schema/schema_test.go new file mode 100644 index 0000000..14794f5 --- /dev/null +++ b/v2/internal/schema/schema_test.go @@ -0,0 +1,181 @@ +package schema + +import ( + "encoding/json" + "testing" +) + +type SimpleParams struct { + Name string `json:"name" description:"The name"` + Age int `json:"age" description:"The age"` +} + +type OptionalParams struct { + Required string `json:"required" description:"A required field"` + Optional *string `json:"optional,omitempty" description:"An optional field"` +} + +type EnumParams struct { + Color string `json:"color" description:"The color" enum:"red,green,blue"` +} + +type NestedParams struct { + Inner SimpleParams `json:"inner" description:"Nested object"` +} + +type ArrayParams struct { + Items []string `json:"items" description:"A list of items"` +} + +type EmbeddedBase struct { + ID string `json:"id" description:"The ID"` +} + +type EmbeddedParams struct { + EmbeddedBase + Name string `json:"name" description:"The name"` +} + +func TestFromStruct_Simple(t *testing.T) { + s := FromStruct(SimpleParams{}) + + if s["type"] != "object" { + t.Errorf("expected type=object, got %v", s["type"]) + } + + props, ok := s["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties to be map[string]any") + } + + if len(props) != 2 { + t.Errorf("expected 2 properties, got %d", len(props)) + } + + nameSchema, ok := props["name"].(map[string]any) + if !ok { + t.Fatal("expected name property to be map[string]any") + } + if nameSchema["type"] != "string" { + t.Errorf("expected name type=string, got %v", nameSchema["type"]) + } + if nameSchema["description"] != "The name" { + t.Errorf("expected name description='The name', got %v", nameSchema["description"]) + } + + ageSchema, ok := props["age"].(map[string]any) + if !ok { + t.Fatal("expected age property to be map[string]any") + } + if ageSchema["type"] != "integer" { + t.Errorf("expected age type=integer, got %v", ageSchema["type"]) + } + + required, ok := s["required"].([]string) + if !ok { + t.Fatal("expected required to be []string") + } + if len(required) != 2 { + t.Errorf("expected 2 required fields, got %d", len(required)) + } +} + +func TestFromStruct_Optional(t *testing.T) { + s := FromStruct(OptionalParams{}) + + required, ok := s["required"].([]string) + if !ok { + t.Fatal("expected required to be []string") + } + + // Only "required" field should be required, not "optional" + if len(required) != 1 { + t.Errorf("expected 1 required field, got %d: %v", len(required), required) + } + if required[0] != "required" { + t.Errorf("expected required field 'required', got %v", required[0]) + } +} + +func TestFromStruct_Enum(t *testing.T) { + s := FromStruct(EnumParams{}) + + props := s["properties"].(map[string]any) + colorSchema := props["color"].(map[string]any) + + if colorSchema["type"] != "string" { + t.Errorf("expected enum type=string, got %v", colorSchema["type"]) + } + + enums, ok := colorSchema["enum"].([]string) + if !ok { + t.Fatal("expected enum to be []string") + } + if len(enums) != 3 { + t.Errorf("expected 3 enum values, got %d", len(enums)) + } +} + +func TestFromStruct_Nested(t *testing.T) { + s := FromStruct(NestedParams{}) + + props := s["properties"].(map[string]any) + innerSchema := props["inner"].(map[string]any) + + if innerSchema["type"] != "object" { + t.Errorf("expected nested type=object, got %v", innerSchema["type"]) + } + + innerProps := innerSchema["properties"].(map[string]any) + if len(innerProps) != 2 { + t.Errorf("expected 2 inner properties, got %d", len(innerProps)) + } +} + +func TestFromStruct_Array(t *testing.T) { + s := FromStruct(ArrayParams{}) + + props := s["properties"].(map[string]any) + itemsSchema := props["items"].(map[string]any) + + if itemsSchema["type"] != "array" { + t.Errorf("expected array type=array, got %v", itemsSchema["type"]) + } + + items := itemsSchema["items"].(map[string]any) + if items["type"] != "string" { + t.Errorf("expected items type=string, got %v", items["type"]) + } +} + +func TestFromStruct_Embedded(t *testing.T) { + s := FromStruct(EmbeddedParams{}) + + props := s["properties"].(map[string]any) + + // Should have both ID from embedded and Name + if len(props) != 2 { + t.Errorf("expected 2 properties (flattened), got %d", len(props)) + } + + if _, ok := props["id"]; !ok { + t.Error("expected 'id' property from embedded struct") + } + if _, ok := props["name"]; !ok { + t.Error("expected 'name' property") + } +} + +func TestFromStruct_ValidJSON(t *testing.T) { + s := FromStruct(SimpleParams{}) + + data, err := json.Marshal(s) + if err != nil { + t.Fatalf("schema should be valid JSON: %v", err) + } + + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("schema should round-trip through JSON: %v", err) + } +} diff --git a/v2/llm.go b/v2/llm.go new file mode 100644 index 0000000..460b0a9 --- /dev/null +++ b/v2/llm.go @@ -0,0 +1,199 @@ +package llm + +import ( + "context" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// Client represents an LLM provider. Create with OpenAI(), Anthropic(), Google(). +type Client struct { + p provider.Provider + middleware []Middleware +} + +// newClient creates a Client backed by the given provider. +func newClient(p provider.Provider) *Client { + return &Client{p: p} +} + +// Model returns a Model for the specified model version. +func (c *Client) Model(modelVersion string) *Model { + return &Model{ + provider: c.p, + model: modelVersion, + middleware: c.middleware, + } +} + +// WithMiddleware returns a new Client with additional middleware applied to all models. +func (c *Client) WithMiddleware(mw ...Middleware) *Client { + c2 := &Client{ + p: c.p, + middleware: append(append([]Middleware{}, c.middleware...), mw...), + } + return c2 +} + +// Model represents a specific model from a provider, ready for completions. +type Model struct { + provider provider.Provider + model string + middleware []Middleware +} + +// Complete sends a non-streaming completion request. +func (m *Model) Complete(ctx context.Context, messages []Message, opts ...RequestOption) (Response, error) { + cfg := &requestConfig{} + for _, opt := range opts { + opt(cfg) + } + + chain := m.buildChain() + return chain(ctx, m.model, messages, cfg) +} + +// Stream sends a streaming completion request, returning a StreamReader. +func (m *Model) Stream(ctx context.Context, messages []Message, opts ...RequestOption) (*StreamReader, error) { + cfg := &requestConfig{} + for _, opt := range opts { + opt(cfg) + } + + req := buildProviderRequest(m.model, messages, cfg) + return newStreamReader(ctx, m.provider, req) +} + +// WithMiddleware returns a new Model with additional middleware applied. +func (m *Model) WithMiddleware(mw ...Middleware) *Model { + return &Model{ + provider: m.provider, + model: m.model, + middleware: append(append([]Middleware{}, m.middleware...), mw...), + } +} + +func (m *Model) buildChain() CompletionFunc { + // Base handler that calls the provider + base := func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + req := buildProviderRequest(model, messages, cfg) + resp, err := m.provider.Complete(ctx, req) + if err != nil { + return Response{}, err + } + return convertProviderResponse(resp), nil + } + + // Apply middleware in reverse order (first middleware wraps outermost) + chain := base + for i := len(m.middleware) - 1; i >= 0; i-- { + chain = m.middleware[i](chain) + } + return chain +} + +func buildProviderRequest(model string, messages []Message, cfg *requestConfig) provider.Request { + req := provider.Request{ + Model: model, + Messages: convertMessages(messages), + } + + if cfg.temperature != nil { + req.Temperature = cfg.temperature + } + if cfg.maxTokens != nil { + req.MaxTokens = cfg.maxTokens + } + if cfg.topP != nil { + req.TopP = cfg.topP + } + if len(cfg.stop) > 0 { + req.Stop = cfg.stop + } + + if cfg.tools != nil { + for _, tool := range cfg.tools.AllTools() { + req.Tools = append(req.Tools, provider.ToolDef{ + Name: tool.Name, + Description: tool.Description, + Schema: tool.Schema, + }) + } + } + + return req +} + +func convertMessages(msgs []Message) []provider.Message { + out := make([]provider.Message, len(msgs)) + for i, m := range msgs { + pm := provider.Message{ + Role: string(m.Role), + Content: m.Content.Text, + ToolCallID: m.ToolCallID, + } + for _, img := range m.Content.Images { + pm.Images = append(pm.Images, provider.Image{ + URL: img.URL, + Base64: img.Base64, + ContentType: img.ContentType, + }) + } + for _, tc := range m.ToolCalls { + pm.ToolCalls = append(pm.ToolCalls, provider.ToolCall{ + ID: tc.ID, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + out[i] = pm + } + return out +} + +func convertProviderResponse(resp provider.Response) Response { + r := Response{ + Text: resp.Text, + } + + for _, tc := range resp.ToolCalls { + r.ToolCalls = append(r.ToolCalls, ToolCall{ + ID: tc.ID, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + + if resp.Usage != nil { + r.Usage = &Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.TotalTokens, + } + } + + // Build the assistant message for conversation history + r.message = Message{ + Role: RoleAssistant, + Content: Content{Text: resp.Text}, + ToolCalls: r.ToolCalls, + } + + return r +} + +// --- Provider constructors --- +// These are defined here and delegate to provider-specific packages. +// They are set up via init() in the provider packages, or defined directly. + +// ClientOption configures a client. +type ClientOption func(*clientConfig) + +type clientConfig struct { + baseURL string +} + +// WithBaseURL overrides the API base URL. +func WithBaseURL(url string) ClientOption { + return func(c *clientConfig) { c.baseURL = url } +} diff --git a/v2/mcp.go b/v2/mcp.go new file mode 100644 index 0000000..b6f2716 --- /dev/null +++ b/v2/mcp.go @@ -0,0 +1,264 @@ +package llm + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "sync" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MCPTransport specifies how to connect to an MCP server. +type MCPTransport string + +const ( + MCPStdio MCPTransport = "stdio" + MCPSSE MCPTransport = "sse" + MCPHTTP MCPTransport = "http" +) + +// MCPServer represents a connection to an MCP server. +type MCPServer struct { + name string + transport MCPTransport + + // stdio fields + command string + args []string + env []string + + // network fields + url string + + // internal + client *mcp.Client + session *mcp.ClientSession + tools map[string]*mcp.Tool + mu sync.RWMutex +} + +// MCPOption configures an MCP server. +type MCPOption func(*MCPServer) + +// WithMCPEnv adds environment variables for the subprocess. +func WithMCPEnv(env ...string) MCPOption { + return func(s *MCPServer) { s.env = env } +} + +// WithMCPName sets a friendly name for logging. +func WithMCPName(name string) MCPOption { + return func(s *MCPServer) { s.name = name } +} + +// MCPStdioServer creates and connects to an MCP server via stdio transport. +// +// Example: +// +// server, err := llm.MCPStdioServer(ctx, "npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp") +func MCPStdioServer(ctx context.Context, command string, args ...string) (*MCPServer, error) { + s := &MCPServer{ + name: command, + transport: MCPStdio, + command: command, + args: args, + } + if err := s.connect(ctx); err != nil { + return nil, err + } + return s, nil +} + +// MCPHTTPServer creates and connects to an MCP server via streamable HTTP transport. +// +// Example: +// +// server, err := llm.MCPHTTPServer(ctx, "https://mcp.example.com") +func MCPHTTPServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) { + s := &MCPServer{ + name: url, + transport: MCPHTTP, + url: url, + } + for _, opt := range opts { + opt(s) + } + if err := s.connect(ctx); err != nil { + return nil, err + } + return s, nil +} + +// MCPSSEServer creates and connects to an MCP server via SSE transport. +func MCPSSEServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) { + s := &MCPServer{ + name: url, + transport: MCPSSE, + url: url, + } + for _, opt := range opts { + opt(s) + } + if err := s.connect(ctx); err != nil { + return nil, err + } + return s, nil +} + +func (s *MCPServer) connect(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.session != nil { + return nil + } + + s.client = mcp.NewClient(&mcp.Implementation{ + Name: "go-llm-v2", + Version: "2.0.0", + }, nil) + + var transport mcp.Transport + + switch s.transport { + case MCPSSE: + transport = &mcp.SSEClientTransport{ + Endpoint: s.url, + } + case MCPHTTP: + transport = &mcp.StreamableClientTransport{ + Endpoint: s.url, + } + default: // stdio + cmd := exec.Command(s.command, s.args...) + cmd.Env = append(os.Environ(), s.env...) + transport = &mcp.CommandTransport{ + Command: cmd, + } + } + + session, err := s.client.Connect(ctx, transport, nil) + if err != nil { + return fmt.Errorf("failed to connect to MCP server %s: %w", s.name, err) + } + + s.session = session + + // Load tools + s.tools = make(map[string]*mcp.Tool) + for tool, err := range session.Tools(ctx, nil) { + if err != nil { + s.session.Close() + s.session = nil + return fmt.Errorf("failed to list tools from %s: %w", s.name, err) + } + s.tools[tool.Name] = tool + } + + return nil +} + +// Close closes the connection to the MCP server. +func (s *MCPServer) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.session == nil { + return nil + } + + err := s.session.Close() + s.session = nil + s.tools = nil + return err +} + +// IsConnected returns true if the server is connected. +func (s *MCPServer) IsConnected() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.session != nil +} + +// ListTools returns Tool definitions for all tools this server provides. +func (s *MCPServer) ListTools() []Tool { + s.mu.RLock() + defer s.mu.RUnlock() + + var tools []Tool + for _, t := range s.tools { + tools = append(tools, s.toTool(t)) + } + return tools +} + +// CallTool invokes a tool on the server. +func (s *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) { + s.mu.RLock() + session := s.session + s.mu.RUnlock() + + if session == nil { + return "", fmt.Errorf("%w: %s", ErrNotConnected, s.name) + } + + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: name, + Arguments: arguments, + }) + if err != nil { + return "", err + } + + if len(result.Content) == 0 { + return "", nil + } + + return contentToString(result.Content), nil +} + +func (s *MCPServer) toTool(t *mcp.Tool) Tool { + var inputSchema map[string]any + if t.InputSchema != nil { + data, err := json.Marshal(t.InputSchema) + if err == nil { + _ = json.Unmarshal(data, &inputSchema) + } + } + + if inputSchema == nil { + inputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + + return Tool{ + Name: t.Name, + Description: t.Description, + Schema: inputSchema, + isMCP: true, + mcpServer: s, + } +} + +func contentToString(content []mcp.Content) string { + var parts []string + for _, c := range content { + switch tc := c.(type) { + case *mcp.TextContent: + parts = append(parts, tc.Text) + default: + if data, err := json.Marshal(c); err == nil { + parts = append(parts, string(data)) + } + } + } + if len(parts) == 1 { + return parts[0] + } + data, _ := json.Marshal(parts) + return string(data) +} diff --git a/v2/message.go b/v2/message.go new file mode 100644 index 0000000..43185b5 --- /dev/null +++ b/v2/message.go @@ -0,0 +1,73 @@ +package llm + +// Role represents who authored a message. +type Role string + +const ( + RoleSystem Role = "system" + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleTool Role = "tool" +) + +// Image represents an image attachment. +type Image struct { + // Provide exactly one of URL or Base64. + URL string // HTTP(S) URL + Base64 string // Raw base64-encoded data + ContentType string // MIME type (e.g., "image/png"), required for Base64 +} + +// Content represents message content with optional text and images. +type Content struct { + Text string + Images []Image +} + +// ToolCall represents a tool invocation requested by the assistant. +type ToolCall struct { + ID string + Name string + Arguments string // raw JSON +} + +// Message represents a single message in a conversation. +type Message struct { + Role Role + Content Content + + // ToolCallID is set when Role == RoleTool, identifying which tool call this responds to. + ToolCallID string + + // ToolCalls is set when the assistant requests tool invocations. + ToolCalls []ToolCall +} + +// UserMessage creates a user message with text content. +func UserMessage(text string) Message { + return Message{Role: RoleUser, Content: Content{Text: text}} +} + +// UserMessageWithImages creates a user message with text and images. +func UserMessageWithImages(text string, images ...Image) Message { + return Message{Role: RoleUser, Content: Content{Text: text, Images: images}} +} + +// SystemMessage creates a system prompt message. +func SystemMessage(text string) Message { + return Message{Role: RoleSystem, Content: Content{Text: text}} +} + +// AssistantMessage creates an assistant message with text content. +func AssistantMessage(text string) Message { + return Message{Role: RoleAssistant, Content: Content{Text: text}} +} + +// ToolResultMessage creates a tool result message. +func ToolResultMessage(toolCallID string, result string) Message { + return Message{ + Role: RoleTool, + Content: Content{Text: result}, + ToolCallID: toolCallID, + } +} diff --git a/v2/middleware.go b/v2/middleware.go new file mode 100644 index 0000000..73e1620 --- /dev/null +++ b/v2/middleware.go @@ -0,0 +1,117 @@ +package llm + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" +) + +// CompletionFunc is the signature for the completion call chain. +type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) + +// Middleware wraps a completion call. It receives the next handler in the chain +// and returns a new handler that can inspect/modify the request and response. +type Middleware func(next CompletionFunc) CompletionFunc + +// WithLogging returns middleware that logs requests and responses via slog. +func WithLogging(logger *slog.Logger) Middleware { + return func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + logger.Info("llm request", + "model", model, + "message_count", len(messages), + ) + start := time.Now() + resp, err := next(ctx, model, messages, cfg) + elapsed := time.Since(start) + if err != nil { + logger.Error("llm error", "model", model, "elapsed", elapsed, "error", err) + } else { + logger.Info("llm response", + "model", model, + "elapsed", elapsed, + "text_len", len(resp.Text), + "tool_calls", len(resp.ToolCalls), + ) + } + return resp, err + } + } +} + +// WithRetry returns middleware that retries failed requests with configurable backoff. +func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware { + return func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + if attempt > 0 { + select { + case <-ctx.Done(): + return Response{}, ctx.Err() + case <-time.After(backoff(attempt)): + } + } + resp, err := next(ctx, model, messages, cfg) + if err == nil { + return resp, nil + } + lastErr = err + } + return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr) + } + } +} + +// WithTimeout returns middleware that enforces a per-request timeout. +func WithTimeout(d time.Duration) Middleware { + return func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + ctx, cancel := context.WithTimeout(ctx, d) + defer cancel() + return next(ctx, model, messages, cfg) + } + } +} + +// UsageTracker accumulates token usage statistics across calls. +type UsageTracker struct { + mu sync.Mutex + TotalInput int64 + TotalOutput int64 + TotalRequests int64 +} + +// Add records usage from a single request. +func (ut *UsageTracker) Add(u *Usage) { + if u == nil { + return + } + ut.mu.Lock() + defer ut.mu.Unlock() + ut.TotalInput += int64(u.InputTokens) + ut.TotalOutput += int64(u.OutputTokens) + ut.TotalRequests++ +} + +// Summary returns the accumulated totals. +func (ut *UsageTracker) Summary() (input, output, requests int64) { + ut.mu.Lock() + defer ut.mu.Unlock() + return ut.TotalInput, ut.TotalOutput, ut.TotalRequests +} + +// WithUsageTracking returns middleware that accumulates token usage across calls. +func WithUsageTracking(tracker *UsageTracker) Middleware { + return func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + resp, err := next(ctx, model, messages, cfg) + if err == nil { + tracker.Add(resp.Usage) + } + return resp, err + } + } +} diff --git a/v2/openai/openai.go b/v2/openai/openai.go new file mode 100644 index 0000000..ab20adf --- /dev/null +++ b/v2/openai/openai.go @@ -0,0 +1,323 @@ +// Package openai implements the go-llm v2 provider interface for OpenAI. +package openai + +import ( + "context" + "fmt" + "strings" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/shared" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// Provider implements the provider.Provider interface for OpenAI. +type Provider struct { + apiKey string + baseURL string +} + +// New creates a new OpenAI provider. +func New(apiKey string, baseURL string) *Provider { + return &Provider{apiKey: apiKey, baseURL: baseURL} +} + +// Complete performs a non-streaming completion. +func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + var opts []option.RequestOption + opts = append(opts, option.WithAPIKey(p.apiKey)) + if p.baseURL != "" { + opts = append(opts, option.WithBaseURL(p.baseURL)) + } + + cl := openai.NewClient(opts...) + oaiReq := p.buildRequest(req) + + resp, err := cl.Chat.Completions.New(ctx, oaiReq) + if err != nil { + return provider.Response{}, fmt.Errorf("openai completion error: %w", err) + } + + return p.convertResponse(resp), nil +} + +// Stream performs a streaming completion. +func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + var opts []option.RequestOption + opts = append(opts, option.WithAPIKey(p.apiKey)) + if p.baseURL != "" { + opts = append(opts, option.WithBaseURL(p.baseURL)) + } + + cl := openai.NewClient(opts...) + oaiReq := p.buildRequest(req) + + stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) + + var fullText strings.Builder + var toolCalls []provider.ToolCall + toolCallArgs := map[int]*strings.Builder{} + + for stream.Next() { + chunk := stream.Current() + for _, choice := range chunk.Choices { + // Text delta + if choice.Delta.Content != "" { + fullText.WriteString(choice.Delta.Content) + events <- provider.StreamEvent{ + Type: provider.StreamEventText, + Text: choice.Delta.Content, + } + } + + // Tool call deltas + for _, tc := range choice.Delta.ToolCalls { + idx := int(tc.Index) + + if tc.ID != "" { + // New tool call starting + for len(toolCalls) <= idx { + toolCalls = append(toolCalls, provider.ToolCall{}) + } + toolCalls[idx].ID = tc.ID + toolCalls[idx].Name = tc.Function.Name + toolCallArgs[idx] = &strings.Builder{} + + events <- provider.StreamEvent{ + Type: provider.StreamEventToolStart, + ToolCall: &provider.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + }, + ToolIndex: idx, + } + } + + if tc.Function.Arguments != "" { + if b, ok := toolCallArgs[idx]; ok { + b.WriteString(tc.Function.Arguments) + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolDelta, + ToolIndex: idx, + ToolCall: &provider.ToolCall{ + Arguments: tc.Function.Arguments, + }, + } + } + } + } + } + + if err := stream.Err(); err != nil { + return fmt.Errorf("openai stream error: %w", err) + } + + // Finalize tool calls + for idx := range toolCalls { + if b, ok := toolCallArgs[idx]; ok { + toolCalls[idx].Arguments = b.String() + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolEnd, + ToolIndex: idx, + ToolCall: &toolCalls[idx], + } + } + + // Send done event + events <- provider.StreamEvent{ + Type: provider.StreamEventDone, + Response: &provider.Response{ + Text: fullText.String(), + ToolCalls: toolCalls, + }, + } + + return nil +} + +func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { + oaiReq := openai.ChatCompletionNewParams{ + Model: req.Model, + } + + for _, msg := range req.Messages { + oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) + } + + for _, tool := range req.Tools { + oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ + Type: "function", + Function: shared.FunctionDefinitionParam{ + Name: tool.Name, + Description: openai.String(tool.Description), + Parameters: openai.FunctionParameters(tool.Schema), + }, + }) + } + + if req.Temperature != nil { + // o* and gpt-5* models don't support custom temperatures + if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") { + oaiReq.Temperature = openai.Float(*req.Temperature) + } + } + + if req.MaxTokens != nil { + oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) + } + + if req.TopP != nil { + oaiReq.TopP = openai.Float(*req.TopP) + } + + if len(req.Stop) > 0 { + oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} + } + + return oaiReq +} + +func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { + var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam + var textContent param.Opt[string] + + for _, img := range msg.Images { + var url string + if img.Base64 != "" { + url = "data:" + img.ContentType + ";base64," + img.Base64 + } else if img.URL != "" { + url = img.URL + } + if url != "" { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: url, + }, + }, + }, + ) + } + } + + if msg.Content != "" { + if len(arrayOfContentParts) > 0 { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: msg.Content, + }, + }, + ) + } else { + textContent = openai.String(msg.Content) + } + } + + // Determine if this model uses developer messages instead of system + useDeveloper := false + parts := strings.Split(model, "-") + if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { + useDeveloper = true + } + + switch msg.Role { + case "system": + if useDeveloper { + return openai.ChatCompletionMessageParamUnion{ + OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + } + return openai.ChatCompletionMessageParamUnion{ + OfSystem: &openai.ChatCompletionSystemMessageParam{ + Content: openai.ChatCompletionSystemMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + + case "user": + return openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + OfArrayOfContentParts: arrayOfContentParts, + }, + }, + } + + case "assistant": + as := &openai.ChatCompletionAssistantMessageParam{} + if msg.Content != "" { + as.Content.OfString = openai.String(msg.Content) + } + for _, tc := range msg.ToolCalls { + as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ + ID: tc.ID, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: tc.Name, + Arguments: tc.Arguments, + }, + }) + } + return openai.ChatCompletionMessageParamUnion{OfAssistant: as} + + case "tool": + return openai.ChatCompletionMessageParamUnion{ + OfTool: &openai.ChatCompletionToolMessageParam{ + ToolCallID: msg.ToolCallID, + Content: openai.ChatCompletionToolMessageParamContentUnion{ + OfString: openai.String(msg.Content), + }, + }, + } + } + + // Fallback to user message + return openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + }, + }, + } +} + +func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response { + var res provider.Response + + if resp == nil || len(resp.Choices) == 0 { + return res + } + + choice := resp.Choices[0] + res.Text = choice.Message.Content + + for _, tc := range choice.Message.ToolCalls { + res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: strings.TrimSpace(tc.Function.Arguments), + }) + } + + if resp.Usage.TotalTokens > 0 { + res.Usage = &provider.Usage{ + InputTokens: int(resp.Usage.PromptTokens), + OutputTokens: int(resp.Usage.CompletionTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return res +} diff --git a/v2/provider/provider.go b/v2/provider/provider.go new file mode 100644 index 0000000..02e79d7 --- /dev/null +++ b/v2/provider/provider.go @@ -0,0 +1,92 @@ +// Package provider defines the interface that LLM backend implementations must satisfy. +package provider + +import "context" + +// Message is the provider-level message representation. +type Message struct { + Role string + Content string + Images []Image + ToolCalls []ToolCall + ToolCallID string +} + +// Image represents an image attachment at the provider level. +type Image struct { + URL string + Base64 string + ContentType string +} + +// ToolCall represents a tool invocation requested by the model. +type ToolCall struct { + ID string + Name string + Arguments string // raw JSON +} + +// ToolDef defines a tool available to the model. +type ToolDef struct { + Name string + Description string + Schema map[string]any // JSON Schema +} + +// Request is a completion request at the provider level. +type Request struct { + Model string + Messages []Message + Tools []ToolDef + Temperature *float64 + MaxTokens *int + TopP *float64 + Stop []string +} + +// Response is a completion response at the provider level. +type Response struct { + Text string + ToolCalls []ToolCall + Usage *Usage +} + +// Usage captures token consumption. +type Usage struct { + InputTokens int + OutputTokens int + TotalTokens int +} + +// StreamEventType identifies the kind of stream event. +type StreamEventType int + +const ( + StreamEventText StreamEventType = iota // Text content delta + StreamEventToolStart // Tool call begins + StreamEventToolDelta // Tool call argument delta + StreamEventToolEnd // Tool call complete + StreamEventDone // Stream complete + StreamEventError // Error occurred +) + +// StreamEvent represents a single event in a streaming response. +type StreamEvent struct { + Type StreamEventType + Text string + ToolCall *ToolCall + ToolIndex int + Error error + Response *Response +} + +// Provider is the interface that LLM backends implement. +type Provider interface { + // Complete performs a non-streaming completion. + Complete(ctx context.Context, req Request) (Response, error) + + // Stream performs a streaming completion, sending events to the channel. + // The provider MUST close the channel when done. + // The provider MUST send exactly one StreamEventDone as the last non-error event. + Stream(ctx context.Context, req Request, events chan<- StreamEvent) error +} diff --git a/v2/request.go b/v2/request.go new file mode 100644 index 0000000..198538f --- /dev/null +++ b/v2/request.go @@ -0,0 +1,37 @@ +package llm + +// RequestOption configures a single completion request. +type RequestOption func(*requestConfig) + +type requestConfig struct { + tools *ToolBox + temperature *float64 + maxTokens *int + topP *float64 + stop []string +} + +// WithTools attaches a toolbox to the request. +func WithTools(tb *ToolBox) RequestOption { + return func(c *requestConfig) { c.tools = tb } +} + +// WithTemperature sets the sampling temperature. +func WithTemperature(t float64) RequestOption { + return func(c *requestConfig) { c.temperature = &t } +} + +// WithMaxTokens sets the maximum number of tokens to generate. +func WithMaxTokens(n int) RequestOption { + return func(c *requestConfig) { c.maxTokens = &n } +} + +// WithTopP sets the nucleus sampling parameter. +func WithTopP(p float64) RequestOption { + return func(c *requestConfig) { c.topP = &p } +} + +// WithStop sets stop sequences. +func WithStop(sequences ...string) RequestOption { + return func(c *requestConfig) { c.stop = sequences } +} diff --git a/v2/response.go b/v2/response.go new file mode 100644 index 0000000..a19c397 --- /dev/null +++ b/v2/response.go @@ -0,0 +1,34 @@ +package llm + +// Response represents the result of a completion request. +type Response struct { + // Text is the assistant's text content. Empty if only tool calls. + Text string + + // ToolCalls contains any tool invocations the assistant requested. + ToolCalls []ToolCall + + // Usage contains token usage information (if available from provider). + Usage *Usage + + // message is the full assistant message for this response. + message Message +} + +// Message returns the full assistant Message for this response, +// suitable for appending to the conversation history. +func (r Response) Message() Message { + return r.message +} + +// HasToolCalls returns true if the response contains tool call requests. +func (r Response) HasToolCalls() bool { + return len(r.ToolCalls) > 0 +} + +// Usage captures token consumption. +type Usage struct { + InputTokens int + OutputTokens int + TotalTokens int +} diff --git a/v2/stream.go b/v2/stream.go new file mode 100644 index 0000000..8caf8a0 --- /dev/null +++ b/v2/stream.go @@ -0,0 +1,163 @@ +package llm + +import ( + "context" + "fmt" + "io" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// StreamEventType identifies the kind of stream event. +type StreamEventType = provider.StreamEventType + +const ( + StreamEventText = provider.StreamEventText + StreamEventToolStart = provider.StreamEventToolStart + StreamEventToolDelta = provider.StreamEventToolDelta + StreamEventToolEnd = provider.StreamEventToolEnd + StreamEventDone = provider.StreamEventDone + StreamEventError = provider.StreamEventError +) + +// StreamEvent represents a single event in a streaming response. +type StreamEvent struct { + Type StreamEventType + + // Text is set for StreamEventText — the text delta. + Text string + + // ToolCall is set for StreamEventToolStart/ToolDelta/ToolEnd. + ToolCall *ToolCall + + // ToolIndex identifies which tool call is being updated. + ToolIndex int + + // Error is set for StreamEventError. + Error error + + // Response is set for StreamEventDone — the complete, aggregated response. + Response *Response +} + +// StreamReader reads streaming events from an LLM response. +// Must be closed when done. +type StreamReader struct { + events <-chan StreamEvent + cancel context.CancelFunc + done bool +} + +func newStreamReader(ctx context.Context, p provider.Provider, req provider.Request) (*StreamReader, error) { + ctx, cancel := context.WithCancel(ctx) + providerEvents := make(chan provider.StreamEvent, 32) + + publicEvents := make(chan StreamEvent, 32) + + go func() { + defer close(publicEvents) + for pev := range providerEvents { + ev := convertStreamEvent(pev) + select { + case publicEvents <- ev: + case <-ctx.Done(): + return + } + } + }() + + go func() { + defer close(providerEvents) + if err := p.Stream(ctx, req, providerEvents); err != nil { + select { + case providerEvents <- provider.StreamEvent{Type: provider.StreamEventError, Error: err}: + default: + } + } + }() + + return &StreamReader{ + events: publicEvents, + cancel: cancel, + }, nil +} + +func convertStreamEvent(pev provider.StreamEvent) StreamEvent { + ev := StreamEvent{ + Type: pev.Type, + Text: pev.Text, + ToolIndex: pev.ToolIndex, + } + if pev.Error != nil { + ev.Error = pev.Error + } + if pev.ToolCall != nil { + tc := ToolCall{ + ID: pev.ToolCall.ID, + Name: pev.ToolCall.Name, + Arguments: pev.ToolCall.Arguments, + } + ev.ToolCall = &tc + } + if pev.Response != nil { + resp := convertProviderResponse(*pev.Response) + ev.Response = &resp + } + return ev +} + +// Next returns the next event from the stream. +// Returns io.EOF when the stream is complete. +func (sr *StreamReader) Next() (StreamEvent, error) { + if sr.done { + return StreamEvent{}, io.EOF + } + ev, ok := <-sr.events + if !ok { + sr.done = true + return StreamEvent{}, io.EOF + } + if ev.Type == StreamEventError { + return ev, ev.Error + } + if ev.Type == StreamEventDone { + sr.done = true + } + return ev, nil +} + +// Close closes the stream reader and releases resources. +func (sr *StreamReader) Close() error { + sr.cancel() + return nil +} + +// Collect reads all events and returns the final aggregated Response. +func (sr *StreamReader) Collect() (Response, error) { + var lastResp *Response + for { + ev, err := sr.Next() + if err == io.EOF { + break + } + if err != nil { + return Response{}, err + } + if ev.Type == StreamEventDone && ev.Response != nil { + lastResp = ev.Response + } + } + if lastResp == nil { + return Response{}, fmt.Errorf("stream completed without final response") + } + return *lastResp, nil +} + +// Text is a convenience that collects the stream and returns just the text. +func (sr *StreamReader) Text() (string, error) { + resp, err := sr.Collect() + if err != nil { + return "", err + } + return resp.Text, nil +} diff --git a/v2/tool.go b/v2/tool.go new file mode 100644 index 0000000..c90f1c7 --- /dev/null +++ b/v2/tool.go @@ -0,0 +1,193 @@ +package llm + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema" +) + +// Tool defines a tool that the LLM can invoke. +type Tool struct { + // Name is the tool's unique identifier. + Name string + + // Description tells the LLM what this tool does. + Description string + + // Schema is the JSON Schema for the tool's parameters. + Schema map[string]any + + // fn holds the implementation function (set via Define or DefineSimple). + fn reflect.Value + pTyp reflect.Type // nil for parameterless tools + + // isMCP indicates this tool is provided by an MCP server. + isMCP bool + mcpServer *MCPServer +} + +// Define creates a tool from a typed handler function. +// T must be a struct. Struct fields become the tool's parameters. +// +// Struct tags: +// - `json:"name"` — parameter name +// - `description:"..."` — parameter description +// - `enum:"a,b,c"` — enum constraint +// +// Pointer fields are optional; non-pointer fields are required. +// +// Example: +// +// type WeatherParams struct { +// City string `json:"city" description:"The city to query"` +// Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"` +// } +// +// llm.Define[WeatherParams]("get_weather", "Get weather for a city", +// func(ctx context.Context, p WeatherParams) (string, error) { +// return fmt.Sprintf("72F in %s", p.City), nil +// }, +// ) +func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool { + var zero T + return Tool{ + Name: name, + Description: description, + Schema: schema.FromStruct(zero), + fn: reflect.ValueOf(fn), + pTyp: reflect.TypeOf(zero), + } +} + +// DefineSimple creates a parameterless tool. +// +// Example: +// +// llm.DefineSimple("get_time", "Get the current time", +// func(ctx context.Context) (string, error) { +// return time.Now().Format(time.RFC3339), nil +// }, +// ) +func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool { + return Tool{ + Name: name, + Description: description, + Schema: map[string]any{"type": "object", "properties": map[string]any{}}, + fn: reflect.ValueOf(fn), + } +} + +// Execute runs the tool with the given JSON arguments string. +func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) { + if t.isMCP { + var args map[string]any + if argsJSON != "" && argsJSON != "{}" { + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return "", fmt.Errorf("invalid MCP tool arguments: %w", err) + } + } + return t.mcpServer.CallTool(ctx, t.Name, args) + } + + // Parameterless tool + if t.pTyp == nil { + out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx)}) + if !out[1].IsNil() { + return "", out[1].Interface().(error) + } + return out[0].String(), nil + } + + // Typed tool: unmarshal JSON into the struct, call the function + p := reflect.New(t.pTyp) + if argsJSON != "" && argsJSON != "{}" { + if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil { + return "", fmt.Errorf("invalid tool arguments: %w", err) + } + } + out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) + if !out[1].IsNil() { + return "", out[1].Interface().(error) + } + return out[0].String(), nil +} + +// ToolBox is a collection of tools available for use by an LLM. +type ToolBox struct { + tools map[string]Tool + mcpServers []*MCPServer +} + +// NewToolBox creates a new ToolBox from the given tools. +func NewToolBox(tools ...Tool) *ToolBox { + tb := &ToolBox{tools: make(map[string]Tool)} + for _, t := range tools { + tb.tools[t.Name] = t + } + return tb +} + +// Add adds tools to the toolbox and returns it for chaining. +func (tb *ToolBox) Add(tools ...Tool) *ToolBox { + if tb.tools == nil { + tb.tools = make(map[string]Tool) + } + for _, t := range tools { + tb.tools[t.Name] = t + } + return tb +} + +// AddMCP adds an MCP server's tools to the toolbox. The server must be connected. +func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox { + if tb.tools == nil { + tb.tools = make(map[string]Tool) + } + tb.mcpServers = append(tb.mcpServers, server) + + for _, tool := range server.ListTools() { + tb.tools[tool.Name] = tool + } + return tb +} + +// AllTools returns all tools (local + MCP) as a slice. +func (tb *ToolBox) AllTools() []Tool { + if tb == nil { + return nil + } + tools := make([]Tool, 0, len(tb.tools)) + for _, t := range tb.tools { + tools = append(tools, t) + } + return tools +} + +// Execute executes a tool call by name. +func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) { + if tb == nil { + return "", ErrNoToolsConfigured + } + tool, ok := tb.tools[call.Name] + if !ok { + return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name) + } + return tool.Execute(ctx, call.Arguments) +} + +// ExecuteAll executes all tool calls and returns tool result messages. +func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) { + var results []Message + for _, call := range calls { + result, err := tb.Execute(ctx, call) + text := result + if err != nil { + text = "Error: " + err.Error() + } + results = append(results, ToolResultMessage(call.ID, text)) + } + return results, nil +} diff --git a/v2/tool_test.go b/v2/tool_test.go new file mode 100644 index 0000000..b69fdb1 --- /dev/null +++ b/v2/tool_test.go @@ -0,0 +1,139 @@ +package llm + +import ( + "context" + "encoding/json" + "testing" +) + +type calcParams struct { + A float64 `json:"a" description:"First number"` + B float64 `json:"b" description:"Second number"` + Op string `json:"op" description:"Operation" enum:"add,sub,mul,div"` +} + +func TestDefine(t *testing.T) { + tool := Define[calcParams]("calc", "Calculator", + func(ctx context.Context, p calcParams) (string, error) { + var result float64 + switch p.Op { + case "add": + result = p.A + p.B + case "sub": + result = p.A - p.B + case "mul": + result = p.A * p.B + case "div": + result = p.A / p.B + } + b, err := json.Marshal(result) + return string(b), err + }, + ) + + if tool.Name != "calc" { + t.Errorf("expected name 'calc', got %q", tool.Name) + } + if tool.Description != "Calculator" { + t.Errorf("expected description 'Calculator', got %q", tool.Description) + } + if tool.Schema["type"] != "object" { + t.Errorf("expected schema type=object, got %v", tool.Schema["type"]) + } + + // Test execution + result, err := tool.Execute(context.Background(), `{"a": 10, "b": 3, "op": "add"}`) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if result != "13" { + t.Errorf("expected '13', got %q", result) + } +} + +func TestDefineSimple(t *testing.T) { + tool := DefineSimple("hello", "Say hello", + func(ctx context.Context) (string, error) { + return "Hello, world!", nil + }, + ) + + result, err := tool.Execute(context.Background(), "") + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if result != "Hello, world!" { + t.Errorf("expected 'Hello, world!', got %q", result) + } +} + +func TestToolBox(t *testing.T) { + tool1 := DefineSimple("tool1", "Tool 1", func(ctx context.Context) (string, error) { + return "result1", nil + }) + tool2 := DefineSimple("tool2", "Tool 2", func(ctx context.Context) (string, error) { + return "result2", nil + }) + + tb := NewToolBox(tool1, tool2) + + tools := tb.AllTools() + if len(tools) != 2 { + t.Errorf("expected 2 tools, got %d", len(tools)) + } + + result, err := tb.Execute(context.Background(), ToolCall{ID: "1", Name: "tool1"}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if result != "result1" { + t.Errorf("expected 'result1', got %q", result) + } + + // Test not found + _, err = tb.Execute(context.Background(), ToolCall{ID: "x", Name: "nonexistent"}) + if err == nil { + t.Error("expected error for nonexistent tool") + } +} + +func TestToolBoxExecuteAll(t *testing.T) { + tb := NewToolBox( + DefineSimple("t1", "T1", func(ctx context.Context) (string, error) { + return "r1", nil + }), + DefineSimple("t2", "T2", func(ctx context.Context) (string, error) { + return "r2", nil + }), + ) + + calls := []ToolCall{ + {ID: "c1", Name: "t1"}, + {ID: "c2", Name: "t2"}, + } + + msgs, err := tb.ExecuteAll(context.Background(), calls) + if err != nil { + t.Fatalf("execute all failed: %v", err) + } + + if len(msgs) != 2 { + t.Fatalf("expected 2 messages, got %d", len(msgs)) + } + + if msgs[0].Role != RoleTool { + t.Errorf("expected role=tool, got %v", msgs[0].Role) + } + if msgs[0].ToolCallID != "c1" { + t.Errorf("expected toolCallID=c1, got %v", msgs[0].ToolCallID) + } + if msgs[0].Content.Text != "r1" { + t.Errorf("expected content=r1, got %v", msgs[0].Content.Text) + } +} + +// jsonMarshal helper for calcParams test +func (p calcParams) jsonMarshal(result float64) (string, error) { + b, err := json.Marshal(result) + return string(b), err +} diff --git a/v2/tools/browser.go b/v2/tools/browser.go new file mode 100644 index 0000000..503fe61 --- /dev/null +++ b/v2/tools/browser.go @@ -0,0 +1,59 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// BrowserParams defines parameters for the browser tool. +type BrowserParams struct { + URL string `json:"url" description:"The URL to fetch and extract text from"` +} + +// Browser creates a simple web content fetcher tool. +// It fetches a URL and returns the text content. +// +// For a full headless browser, consider using an MCP server like Playwright MCP. +// +// Example: +// +// tools := llm.NewToolBox(tools.Browser()) +func Browser() llm.Tool { + return llm.Define[BrowserParams]("browser", "Fetch a web page and return its text content", + func(ctx context.Context, p BrowserParams) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.URL, nil) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + req.Header.Set("User-Agent", "go-llm/2.0 (Web Fetcher)") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetching URL: %w", err) + } + defer resp.Body.Close() + + // Limit to 1MB + limited := io.LimitReader(resp.Body, 1<<20) + body, err := io.ReadAll(limited) + if err != nil { + return "", fmt.Errorf("reading body: %w", err) + } + + result := map[string]any{ + "url": p.URL, + "status": resp.StatusCode, + "content_type": resp.Header.Get("Content-Type"), + "body": string(body), + } + + out, _ := json.MarshalIndent(result, "", " ") + return string(out), nil + }, + ) +} diff --git a/v2/tools/exec.go b/v2/tools/exec.go new file mode 100644 index 0000000..187082f --- /dev/null +++ b/v2/tools/exec.go @@ -0,0 +1,101 @@ +package tools + +import ( + "context" + "fmt" + "os/exec" + "runtime" + "strings" + "time" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// ExecParams defines parameters for the exec tool. +type ExecParams struct { + Command string `json:"command" description:"The shell command to execute"` +} + +// ExecOption configures the exec tool. +type ExecOption func(*execConfig) + +type execConfig struct { + allowedCommands []string + workDir string + timeout time.Duration +} + +// WithAllowedCommands restricts which commands can be executed. +// If empty, all commands are allowed. +func WithAllowedCommands(cmds []string) ExecOption { + return func(c *execConfig) { c.allowedCommands = cmds } +} + +// WithWorkDir sets the working directory for command execution. +func WithWorkDir(dir string) ExecOption { + return func(c *execConfig) { c.workDir = dir } +} + +// WithExecTimeout sets the maximum execution time. +func WithExecTimeout(d time.Duration) ExecOption { + return func(c *execConfig) { c.timeout = d } +} + +// Exec creates a shell command execution tool. +// +// Example: +// +// tools := llm.NewToolBox( +// tools.Exec(tools.WithAllowedCommands([]string{"ls", "cat", "grep"})), +// ) +func Exec(opts ...ExecOption) llm.Tool { + cfg := &execConfig{ + timeout: 30 * time.Second, + } + for _, opt := range opts { + opt(cfg) + } + + return llm.Define[ExecParams]("exec", "Execute a shell command and return its output", + func(ctx context.Context, p ExecParams) (string, error) { + // Check allowed commands + if len(cfg.allowedCommands) > 0 { + parts := strings.Fields(p.Command) + if len(parts) == 0 { + return "", fmt.Errorf("empty command") + } + allowed := false + for _, cmd := range cfg.allowedCommands { + if parts[0] == cmd { + allowed = true + break + } + } + if !allowed { + return "", fmt.Errorf("command %q is not in the allowed list", parts[0]) + } + } + + ctx, cancel := context.WithTimeout(ctx, cfg.timeout) + defer cancel() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.CommandContext(ctx, "cmd", "/C", p.Command) + } else { + cmd = exec.CommandContext(ctx, "sh", "-c", p.Command) + } + + if cfg.workDir != "" { + cmd.Dir = cfg.workDir + } + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Sprintf("Error: %s\nOutput: %s", err.Error(), string(output)), nil + } + + return string(output), nil + }, + ) +} diff --git a/v2/tools/http.go b/v2/tools/http.go new file mode 100644 index 0000000..7c80496 --- /dev/null +++ b/v2/tools/http.go @@ -0,0 +1,75 @@ +package tools + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// HTTPParams defines parameters for the HTTP request tool. +type HTTPParams struct { + Method string `json:"method" description:"HTTP method" enum:"GET,POST,PUT,DELETE,PATCH,HEAD"` + URL string `json:"url" description:"Request URL"` + Headers map[string]string `json:"headers,omitempty" description:"Request headers"` + Body *string `json:"body,omitempty" description:"Request body"` +} + +// HTTP creates an HTTP request tool. +// +// Example: +// +// tools := llm.NewToolBox(tools.HTTP()) +func HTTP() llm.Tool { + return llm.Define[HTTPParams]("http_request", "Make an HTTP request and return the response", + func(ctx context.Context, p HTTPParams) (string, error) { + var bodyReader io.Reader + if p.Body != nil { + bodyReader = bytes.NewBufferString(*p.Body) + } + + req, err := http.NewRequestWithContext(ctx, p.Method, p.URL, bodyReader) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + + for k, v := range p.Headers { + req.Header.Set(k, v) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + // Limit to 1MB + limited := io.LimitReader(resp.Body, 1<<20) + body, err := io.ReadAll(limited) + if err != nil { + return "", fmt.Errorf("reading response: %w", err) + } + + headers := map[string]string{} + for k, v := range resp.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + + result := map[string]any{ + "status": resp.StatusCode, + "status_text": resp.Status, + "headers": headers, + "body": string(body), + } + + out, _ := json.MarshalIndent(result, "", " ") + return string(out), nil + }, + ) +} diff --git a/v2/tools/readfile.go b/v2/tools/readfile.go new file mode 100644 index 0000000..61a5dfc --- /dev/null +++ b/v2/tools/readfile.go @@ -0,0 +1,81 @@ +package tools + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// ReadFileParams defines parameters for the read file tool. +type ReadFileParams struct { + Path string `json:"path" description:"File path to read"` + Start *int `json:"start,omitempty" description:"Starting line number (1-based, inclusive)"` + End *int `json:"end,omitempty" description:"Ending line number (1-based, inclusive)"` +} + +// ReadFile creates a file reading tool. +// +// Example: +// +// tools := llm.NewToolBox(tools.ReadFile()) +func ReadFile() llm.Tool { + return llm.Define[ReadFileParams]("read_file", "Read the contents of a file", + func(ctx context.Context, p ReadFileParams) (string, error) { + f, err := os.Open(p.Path) + if err != nil { + return "", fmt.Errorf("opening file: %w", err) + } + defer f.Close() + + // If no line range specified, read the whole file (limited to 1MB) + if p.Start == nil && p.End == nil { + info, err := f.Stat() + if err != nil { + return "", fmt.Errorf("stat file: %w", err) + } + if info.Size() > 1<<20 { + return "", fmt.Errorf("file too large (%d bytes), use start/end to read a range", info.Size()) + } + data, err := os.ReadFile(p.Path) + if err != nil { + return "", fmt.Errorf("reading file: %w", err) + } + return string(data), nil + } + + // Read specific line range + start := 1 + end := -1 + if p.Start != nil { + start = *p.Start + } + if p.End != nil { + end = *p.End + } + + var lines []string + scanner := bufio.NewScanner(f) + lineNum := 0 + for scanner.Scan() { + lineNum++ + if lineNum < start { + continue + } + if end > 0 && lineNum > end { + break + } + lines = append(lines, fmt.Sprintf("%d: %s", lineNum, scanner.Text())) + } + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("scanning file: %w", err) + } + + return strings.Join(lines, "\n"), nil + }, + ) +} diff --git a/v2/tools/websearch.go b/v2/tools/websearch.go new file mode 100644 index 0000000..84a6226 --- /dev/null +++ b/v2/tools/websearch.go @@ -0,0 +1,101 @@ +// Package tools provides ready-to-use tool implementations for common agent patterns. +package tools + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// WebSearchParams defines parameters for the web search tool. +type WebSearchParams struct { + Query string `json:"query" description:"The search query"` + Count *int `json:"count,omitempty" description:"Number of results to return (default 5, max 20)"` +} + +// WebSearch creates a web search tool using the Brave Search API. +// +// Get a free API key at https://brave.com/search/api/ +// +// Example: +// +// tools := llm.NewToolBox(tools.WebSearch("your-brave-api-key")) +func WebSearch(apiKey string) llm.Tool { + return llm.Define[WebSearchParams]("web_search", "Search the web for information using Brave Search", + func(ctx context.Context, p WebSearchParams) (string, error) { + count := 5 + if p.Count != nil && *p.Count > 0 { + count = *p.Count + if count > 20 { + count = 20 + } + } + + u := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", + url.QueryEscape(p.Query), count) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("X-Subscription-Token", apiKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("search request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("search API returned %d: %s", resp.StatusCode, string(body)) + } + + // Parse and simplify the response + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return string(body), nil + } + + type result struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` + } + + var results []result + if web, ok := raw["web"].(map[string]any); ok { + if items, ok := web["results"].([]any); ok { + for _, item := range items { + if m, ok := item.(map[string]any); ok { + r := result{} + if t, ok := m["title"].(string); ok { + r.Title = t + } + if u, ok := m["url"].(string); ok { + r.URL = u + } + if d, ok := m["description"].(string); ok { + r.Description = d + } + results = append(results, r) + } + } + } + } + + out, _ := json.MarshalIndent(results, "", " ") + return string(out), nil + }, + ) +} diff --git a/v2/tools/writefile.go b/v2/tools/writefile.go new file mode 100644 index 0000000..68e0d4e --- /dev/null +++ b/v2/tools/writefile.go @@ -0,0 +1,31 @@ +package tools + +import ( + "context" + "fmt" + "os" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// WriteFileParams defines parameters for the write file tool. +type WriteFileParams struct { + Path string `json:"path" description:"File path to write"` + Content string `json:"content" description:"Content to write to the file"` +} + +// WriteFile creates a file writing tool. +// +// Example: +// +// tools := llm.NewToolBox(tools.WriteFile()) +func WriteFile() llm.Tool { + return llm.Define[WriteFileParams]("write_file", "Write content to a file (creates or overwrites)", + func(ctx context.Context, p WriteFileParams) (string, error) { + if err := os.WriteFile(p.Path, []byte(p.Content), 0644); err != nil { + return "", fmt.Errorf("writing file: %w", err) + } + return fmt.Sprintf("Successfully wrote %d bytes to %s", len(p.Content), p.Path), nil + }, + ) +}