Compare commits

...

17 Commits

Author SHA1 Message Date
7e1705c385 feat: add audio input support to v2 providers
All checks were successful
CI / Lint (push) Successful in 9m37s
CI / Root Module (push) Successful in 10m53s
CI / V2 Module (push) Successful in 11m9s
Add Audio struct alongside Image for sending audio attachments to
multimodal LLMs. OpenAI uses input_audio content parts (wav/mp3),
Google Gemini uses genai.NewPartFromBytes, and Anthropic skips
audio gracefully since it's not supported.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 21:00:56 -05:00
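The audio commit above can be illustrated with a minimal sketch of building an OpenAI-style `input_audio` content part from base64 audio. The helper name and map shape here are illustrative, not the library's actual API:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// audioPart builds an OpenAI-style "input_audio" content part from
// base64-encoded audio data. Field names follow the Chat Completions
// audio-input shape; treat the exact layout as an assumption here.
func audioPart(format, base64Data string) map[string]any {
	return map[string]any{
		"type": "input_audio",
		"input_audio": map[string]any{
			"data":   base64Data,
			"format": format, // "wav" or "mp3"
		},
	}
}

func main() {
	b, _ := json.Marshal(audioPart("wav", "UklGRg=="))
	fmt.Println(string(b))
}
```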
fc2218b5fe Add comprehensive test suite for sandbox package (78 tests)
All checks were successful
CI / Lint (push) Successful in 9m35s
CI / V2 Module (push) Successful in 10m39s
CI / Root Module (push) Successful in 11m2s
Expanded from 22 basic tests to 78 tests covering error injection,
task polling, IP discovery, context cancellation, HTTP error codes,
concurrent access, SSH lifecycle, and request verification.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 01:10:59 -05:00
23c9068022 Add sandbox package for isolated Linux containers via Proxmox LXC
All checks were successful
CI / V2 Module (push) Successful in 11m46s
CI / Root Module (push) Successful in 11m50s
CI / Lint (push) Successful in 9m28s
Provides a complete lifecycle manager for ephemeral sandbox environments:
- ProxmoxClient: thin REST wrapper for container CRUD, IP discovery, internet toggle
- SSHExecutor: persistent SSH/SFTP for command execution and file transfer
- Manager/Sandbox: high-level orchestrator tying Proxmox + SSH together
- 22 unit tests with mock Proxmox HTTP server
- Proxmox setup & hardening guide (docs/sandbox-setup.md)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 00:47:45 -05:00
87ec56a2be Add agent sub-package for composable LLM agents
All checks were successful
CI / Lint (push) Successful in 9m46s
CI / V2 Module (push) Successful in 12m5s
CI / Root Module (push) Successful in 12m6s
Introduces v2/agent with a minimal API: Agent, New(), Run(), and AsTool().
Agents wrap a model + system prompt + tools. AsTool() turns an agent into
a llm.Tool, enabling parent agents to delegate to sub-agents through the
normal tool-call loop — no channels, pools, or orchestration needed.

Also exports NewClient(provider.Provider) for custom provider integration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 23:17:19 -05:00
be572a76f4 Add structured output support with Generate[T] and GenerateWith[T]
All checks were successful
CI / Lint (push) Successful in 9m35s
CI / V2 Module (push) Successful in 11m43s
CI / Root Module (push) Successful in 11m53s
Generic functions that use the "hidden tool" technique to force models
to return structured JSON matching a Go struct's schema, replacing the
verbose "tool as structured output" pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:36:33 -05:00
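The "hidden tool" technique above, in miniature: expose a single synthetic tool whose parameters are the desired output schema, require the model to call it, then unmarshal the tool-call arguments into T. The stub below stands in for a real provider call; names are illustrative, not the library's actual signatures:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Recipe is the Go struct whose schema the model must match.
type Recipe struct {
	Name    string   `json:"name"`
	Minutes int      `json:"minutes"`
	Steps   []string `json:"steps"`
}

// generate asks the model to "call" a hidden tool and decodes the
// tool-call arguments (a JSON string) into T.
func generate[T any](callModel func(toolName string) (argsJSON string, err error)) (T, error) {
	var out T
	args, err := callModel("__structured_output")
	if err != nil {
		return out, err
	}
	err = json.Unmarshal([]byte(args), &out)
	return out, err
}

func main() {
	// A stub provider that always "calls" the hidden tool.
	stub := func(string) (string, error) {
		return `{"name":"toast","minutes":5,"steps":["slice","toast"]}`, nil
	}
	r, err := generate[Recipe](stub)
	if err != nil {
		panic(err)
	}
	fmt.Println(r.Name, r.Minutes, len(r.Steps)) // toast 5 2
}
```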
6a7eeef619 Add comprehensive test suite for v2 module with mock provider
All checks were successful
CI / Lint (push) Successful in 9m36s
CI / V2 Module (push) Successful in 11m33s
CI / Root Module (push) Successful in 11m35s
Cover all core library logic (Client, Model, Chat, middleware, streaming,
message conversion, request building) using a configurable mock provider
that avoids real API calls. ~50 tests across 7 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:00:49 -05:00
cbe340ced0 Fix corrupted checksum for charmbracelet/bubbles in go.sum
All checks were successful
CI / Lint (push) Successful in 9m34s
CI / V2 Module (push) Successful in 11m34s
CI / Root Module (push) Successful in 11m35s
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 21:39:32 -05:00
9e288954f2 Add transcription API to v2 module
Some checks failed
CI / Lint (push) Failing after 5m0s
CI / Root Module (push) Failing after 5m3s
CI / V2 Module (push) Successful in 10m48s
Migrate speech-to-text transcription types and OpenAI transcriber
implementation from v1. Types are defined in provider/ to avoid
import cycles and re-exported via type aliases from the root package.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 20:24:20 -05:00
9d6d2c61c3 Add Gitea CI workflow for build, test, vet, and lint
Some checks failed
CI / Lint (push) Failing after 29s
CI / Root Module (push) Failing after 5m19s
CI / V2 Module (push) Successful in 11m9s
Runs on all pushes and PRs:
- Build, vet, and test both root and v2 modules (with -race)
- Verify go.mod/go.sum tidiness for both modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 20:01:01 -05:00
a4cb4baab5 Add go-llm v2: redesigned API for simpler LLM abstraction
v2 is a new Go module (v2/) with a dramatically simpler API:
- Unified Message type (no more Input marker interface)
- Define[T] for ergonomic tool creation with standard context.Context
- Chat session with automatic tool-call loop (agent loop)
- Streaming via pull-based StreamReader
- MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer)
- Middleware support (logging, retry, timeout, usage tracking)
- Decoupled JSON Schema (map[string]any, no provider coupling)
- Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP
- Providers: OpenAI, Anthropic, Google (all with streaming)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 20:00:08 -05:00
85a848d96e Update openaiTranscriber to handle audio file metadata in transcription parameters 2026-01-25 02:38:45 -05:00
8801ce5945 Add OpenAI-based transcriber implementation
- Introduce `openaiTranscriber` for integrating OpenAI's Whisper audio transcription capabilities.
- Define `Transcriber` interface and associated types (`Transcription`, `TranscriptionOptions`, segments, and words).
- Implement transcription logic supporting features like languages, prompts, temperature, and timestamp granularities.
- Add `audioFileToWav` utility using `ffmpeg` for audio file conversion to WAV format.
- Ensure response parsing for structured and verbose JSON outputs.
2026-01-25 01:46:29 -05:00
9c1b4f7e9f Fix checksum typo for github.com/charmbracelet/bubbles in go.sum 2026-01-24 16:59:55 -05:00
2cf75ae07d Add MCP integration with MCPServer for tool-based interactions
- Introduce `MCPServer` to support connecting to MCP servers via stdio, SSE, or HTTP.
- Implement tool fetching, management, and invocation through MCP.
- Add `WithMCPServer` method to `ToolBox` for seamless tool integration.
- Extend schema package to handle raw JSON schemas for MCP tools.
- Update documentation with MCP usage guidelines and examples.
2026-01-24 16:25:28 -05:00
97d54c10ae Implement interactive CLI for LLM providers with chat, tools, and image support
- Add Bubble Tea-based CLI interface for LLM interactions.
- Implement `.env.example` for environment variable setup.
- Add provider, model, and tool selection screens.
- Include support for API key configuration.
- Enable chat interactions with optional image and tool support.
- Introduce core utility functions: image handling, tool execution, chat request management, and response rendering.
- Implement style customization with Lip Gloss.
2026-01-24 15:53:36 -05:00
bf7c86ab2a Refactor: modularize and streamline LLM providers and utility functions
- Migrate `compress_image.go` to `internal/imageutil` for better encapsulation.
- Reorganize LLM provider implementations into distinct packages (`google`, `openai`, and `anthropic`).
- Replace `go_llm` package name with `llm`.
- Refactor internal APIs for improved clarity, including renaming `anthropic` to `anthropicImpl` and `google` to `googleImpl`.
- Add helper methods and restructure message handling for better separation of concerns.
2026-01-24 15:40:38 -05:00
be99af3597 Update all dependencies and migrate to new Google genai SDK
- Update all Go dependencies to latest versions
- Migrate from github.com/google/generative-ai-go/genai to google.golang.org/genai
- Fix google.go to use the new SDK API (NewPartFromText, NewContentFromParts, etc.)
- Update schema package imports to use the new genai package
- Add CLAUDE.md with README maintenance guideline

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-24 15:22:34 -05:00
88 changed files with 13220 additions and 515 deletions

.gitea/workflows/ci.yaml (Normal file, +76 lines)

@@ -0,0 +1,76 @@
name: CI

on:
  push:
    branches: ["*"]
  pull_request:
    branches: ["*"]

jobs:
  root-module:
    name: Root Module
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: "1.24"
      - name: Download dependencies
        run: go mod download
      - name: Build
        run: go build ./...
      - name: Vet
        run: go vet ./...
      - name: Test
        run: go test -race -count=1 ./...

  v2-module:
    name: V2 Module
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: v2
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: "1.24"
      - name: Download dependencies
        run: go mod download
      - name: Build
        run: go build ./...
      - name: Vet
        run: go vet ./...
      - name: Test
        run: go test -race -count=1 ./...

  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: "1.24"
      - name: Check root module tidiness
        run: |
          go mod tidy
          git diff --exit-code go.mod go.sum
      - name: Check v2 module tidiness
        run: |
          cd v2
          go mod tidy
          git diff --exit-code go.mod go.sum

.gitignore (vendored, Normal file, +4 lines)

@@ -0,0 +1,4 @@
.claude
.idea
*.exe
.env

CLAUDE.md (Normal file, +88 lines)

@@ -0,0 +1,88 @@
# CLAUDE.md for go-llm
## Build and Test Commands
- Build project: `go build ./...`
- Run all tests: `go test ./...`
- Run specific test: `go test -v -run <TestName> ./...`
- Tidy dependencies: `go mod tidy`
## Code Style Guidelines
- **Indentation**: Use standard Go tabs for indentation.
- **Naming**:
- Use `camelCase` for internal/private variables and functions.
- Use `PascalCase` for exported types, functions, and struct fields.
- Interface names should be concise (e.g., `LLM`, `ChatCompletion`).
- **Error Handling**:
- Always check and handle errors immediately.
- Wrap errors with context using `fmt.Errorf("%w: ...", err)`.
- Use the project's internal `Error` struct in `error.go` when differentiating between error types is needed.
- **Project Structure**:
- `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`).
- Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`.
- Schema definitions for tool calling are in the `schema/` directory.
- `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers.
- **Imports**: Organize imports into groups: standard library, then third-party libraries.
- **Documentation**: Use standard Go doc comments for exported symbols.
- **README.md**: The README.md file should always be kept up to date with any significant changes to the project.
## CLI Tool
- Build CLI: `go build ./cmd/llm`
- Run CLI: `./llm` (or `llm.exe` on Windows)
- Run without building: `go run ./cmd/llm`
### CLI Features
- Interactive TUI for testing all go-llm features
- Support for OpenAI, Anthropic, and Google providers
- Image input (file path, URL, or base64)
- Tool/function calling with demo tools
- Temperature control and settings
### Key Bindings
- `Enter` - Send message
- `Ctrl+I` - Add image
- `Ctrl+T` - Toggle tools panel
- `Ctrl+P` - Change provider
- `Ctrl+M` - Change model
- `Ctrl+S` - Settings
- `Ctrl+N` - New conversation
- `Esc` - Exit/Cancel
## MCP (Model Context Protocol) Support
The library supports connecting to MCP servers to use their tools. MCP servers can be connected via:
- **stdio**: Run a command as a subprocess
- **sse**: Connect to an SSE endpoint
- **http**: Connect to a streamable HTTP endpoint
### Usage Example
```go
ctx := context.Background()

// Create and connect to an MCP server
server := &llm.MCPServer{
	Name:    "my-server",
	Command: "my-mcp-server",
	Args:    []string{"--some-flag"},
}
if err := server.Connect(ctx); err != nil {
	log.Fatal(err)
}
defer server.Close()

// Add the server to a toolbox
toolbox := llm.NewToolBox().WithMCPServer(server)

// Use the toolbox in requests - MCP tools are automatically available
req := llm.Request{
	Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}},
	Toolbox:  toolbox,
}
```
### MCPServer Options
- `Name`: Friendly name for logging
- `Command`: Command to run (for stdio transport)
- `Args`: Command arguments
- `Env`: Additional environment variables
- `URL`: Endpoint URL (for sse/http transport)
- `Transport`: "stdio" (default), "sse", or "http"

anthropic.go (changed)

@@ -1,4 +1,4 @@
@@ -1,4 +1,4 @@
-package go_llm
+package llm

 import (
 	"context"
@@ -10,19 +10,19 @@ import (
 	"log/slog"
 	"net/http"

-	"gitea.stevedudenhoeffer.com/steve/go-llm/utils"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/internal/imageutil"

 	anth "github.com/liushuangls/go-anthropic/v2"
 )

-type anthropic struct {
+type anthropicImpl struct {
 	key   string
 	model string
 }

-var _ LLM = anthropic{}
+var _ LLM = anthropicImpl{}

-func (a anthropic) ModelVersion(modelVersion string) (ChatCompletion, error) {
+func (a anthropicImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
 	a.model = modelVersion

 	// TODO: model verification?
@@ -36,7 +36,7 @@ func deferClose(c io.Closer) {
 	}
 }

-func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
+func (a anthropicImpl) requestToAnthropicRequest(req Request) anth.MessagesRequest {
 	res := anth.MessagesRequest{
 		Model:     anth.Model(a.model),
 		MaxTokens: 1000,
@@ -90,7 +90,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
 	// Check if image size exceeds 5MiB (5242880 bytes)
 	if len(raw) >= 5242880 {
-		compressed, mime, err := utils.CompressImage(img.Base64, 5*1024*1024)
+		compressed, mime, err := imageutil.CompressImage(img.Base64, 5*1024*1024)

 		// just replace the image with the compressed one
 		if err != nil {
@@ -157,7 +157,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
 		}
 	}

-	for _, tool := range req.Toolbox.functions {
+	for _, tool := range req.Toolbox.Functions() {
 		res.Tools = append(res.Tools, anth.ToolDefinition{
 			Name:        tool.Name,
 			Description: tool.Description,
@@ -177,7 +177,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
 	return res
 }

-func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
+func (a anthropicImpl) responseToLLMResponse(in anth.MessagesResponse) Response {
 	choice := ResponseChoice{}

 	for _, msg := range in.Content {
@@ -212,7 +212,7 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
 	}
 }

-func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) {
+func (a anthropicImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
 	cl := anth.NewClient(a.key)

 	res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req))

cmd/llm/.env.example (Normal file, +11 lines)

@@ -0,0 +1,11 @@
# go-llm CLI Environment Variables
# Copy this file to .env and fill in your API keys
# OpenAI API Key (https://platform.openai.com/api-keys)
OPENAI_API_KEY=
# Anthropic API Key (https://console.anthropic.com/settings/keys)
ANTHROPIC_API_KEY=
# Google AI API Key (https://aistudio.google.com/apikey)
GOOGLE_API_KEY=

cmd/llm/commands.go (Normal file, +182 lines)

@@ -0,0 +1,182 @@
package main
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"os"
"strings"
tea "github.com/charmbracelet/bubbletea"
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// Message types for async operations
// ChatResponseMsg contains the response from a chat completion
type ChatResponseMsg struct {
Response llm.Response
Err error
}
// ToolExecutionMsg contains results from tool execution
type ToolExecutionMsg struct {
Results []llm.ToolCallResponse
Err error
}
// ImageLoadedMsg contains a loaded image
type ImageLoadedMsg struct {
Image llm.Image
Err error
}
// sendChatRequest sends a chat completion request
func sendChatRequest(chat llm.ChatCompletion, req llm.Request) tea.Cmd {
return func() tea.Msg {
resp, err := chat.ChatComplete(context.Background(), req)
return ChatResponseMsg{Response: resp, Err: err}
}
}
// executeTools executes tool calls and returns results
func executeTools(toolbox llm.ToolBox, req llm.Request, resp llm.ResponseChoice) tea.Cmd {
return func() tea.Msg {
ctx := llm.NewContext(context.Background(), req, &resp, nil)
var results []llm.ToolCallResponse
for _, call := range resp.Calls {
result, err := toolbox.Execute(ctx, call)
results = append(results, llm.ToolCallResponse{
ID: call.ID,
Result: result,
Error: err,
})
}
return ToolExecutionMsg{Results: results, Err: nil}
}
}
// loadImageFromPath loads an image from a file path
func loadImageFromPath(path string) tea.Cmd {
return func() tea.Msg {
// Clean up the path
path = strings.TrimSpace(path)
path = strings.Trim(path, "\"'")
// Read the file
data, err := os.ReadFile(path)
if err != nil {
return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)}
}
// Detect content type
contentType := http.DetectContentType(data)
if !strings.HasPrefix(contentType, "image/") {
return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)}
}
// Base64 encode
encoded := base64.StdEncoding.EncodeToString(data)
return ImageLoadedMsg{
Image: llm.Image{
Base64: encoded,
ContentType: contentType,
},
}
}
}
// loadImageFromURL loads an image from a URL
func loadImageFromURL(url string) tea.Cmd {
return func() tea.Msg {
url = strings.TrimSpace(url)
// For URL images, we can just use the URL directly
return ImageLoadedMsg{
Image: llm.Image{
Url: url,
},
}
}
}
// loadImageFromBase64 loads an image from base64 data
func loadImageFromBase64(data string) tea.Cmd {
return func() tea.Msg {
data = strings.TrimSpace(data)
// Check if it's a data URL
if strings.HasPrefix(data, "data:") {
// Parse data URL: data:image/png;base64,....
parts := strings.SplitN(data, ",", 2)
if len(parts) != 2 {
return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")}
}
// Extract content type from first part
mediaType := strings.TrimPrefix(parts[0], "data:")
mediaType = strings.TrimSuffix(mediaType, ";base64")
return ImageLoadedMsg{
Image: llm.Image{
Base64: parts[1],
ContentType: mediaType,
},
}
}
// Assume it's raw base64, try to detect content type
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)}
}
contentType := http.DetectContentType(decoded)
if !strings.HasPrefix(contentType, "image/") {
return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)}
}
return ImageLoadedMsg{
Image: llm.Image{
Base64: data,
ContentType: contentType,
},
}
}
}
// buildRequest builds a chat request from the current state
func buildRequest(m *Model, userText string) llm.Request {
// Create the user message with any pending images
userMsg := llm.Message{
Role: llm.RoleUser,
Text: userText,
Images: m.pendingImages,
}
req := llm.Request{
Conversation: m.conversation,
Messages: []llm.Message{
{Role: llm.RoleSystem, Text: m.systemPrompt},
userMsg,
},
Temperature: m.temperature,
}
// Add toolbox if enabled
if m.toolsEnabled && len(m.toolbox.Functions()) > 0 {
req.Toolbox = m.toolbox.WithRequireTool(false)
}
return req
}
// buildFollowUpRequest builds a follow-up request after tool execution
func buildFollowUpRequest(m *Model, previousReq llm.Request, resp llm.ResponseChoice, toolResults []llm.ToolCallResponse) llm.Request {
return previousReq.NextRequest(resp, toolResults)
}

cmd/llm/main.go (Normal file, +25 lines)

@@ -0,0 +1,25 @@
package main

import (
	"fmt"
	"os"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/joho/godotenv"
)

func main() {
	// Load .env file if it exists (ignore error if not found)
	_ = godotenv.Load()

	p := tea.NewProgram(
		InitialModel(),
		tea.WithAltScreen(),
		tea.WithMouseCellMotion(),
	)
	if _, err := p.Run(); err != nil {
		fmt.Printf("Error running program: %v\n", err)
		os.Exit(1)
	}
}

cmd/llm/model.go (Normal file, +295 lines)

@@ -0,0 +1,295 @@
package main
import (
"os"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// State represents the current view/screen of the application
type State int
const (
StateChat State = iota
StateProviderSelect
StateModelSelect
StateImageInput
StateToolsPanel
StateSettings
StateAPIKeyInput
)
// DisplayMessage represents a message for display in the UI
type DisplayMessage struct {
Role llm.Role
Content string
Images int // number of images attached
}
// ProviderInfo contains information about a provider
type ProviderInfo struct {
Name string
EnvVar string
Models []string
HasAPIKey bool
ModelIndex int
}
// Model is the main Bubble Tea model
type Model struct {
// State
state State
previousState State
// Provider
provider llm.LLM
providerName string
chat llm.ChatCompletion
modelName string
apiKeys map[string]string
providers []ProviderInfo
providerIndex int
// Conversation
conversation []llm.Input
messages []DisplayMessage
// Tools
toolbox llm.ToolBox
toolsEnabled bool
// Settings
systemPrompt string
temperature *float64
// Pending images
pendingImages []llm.Image
// UI Components
input textinput.Model
viewport viewport.Model
viewportReady bool
// Selection state (for lists)
listIndex int
listItems []string
// Dimensions
width int
height int
// Loading state
loading bool
err error
// For API key input
apiKeyInput textinput.Model
}
// InitialModel creates and returns the initial model
func InitialModel() Model {
ti := textinput.New()
ti.Placeholder = "Type your message..."
ti.Focus()
ti.CharLimit = 4096
ti.Width = 60
aki := textinput.New()
aki.Placeholder = "Enter API key..."
aki.CharLimit = 256
aki.Width = 60
aki.EchoMode = textinput.EchoPassword
// Initialize providers with environment variable checks
providers := []ProviderInfo{
{
Name: "OpenAI",
EnvVar: "OPENAI_API_KEY",
Models: []string{
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-3.5-turbo",
"o1",
"o1-mini",
"o1-preview",
"o3-mini",
},
},
{
Name: "Anthropic",
EnvVar: "ANTHROPIC_API_KEY",
Models: []string{
"claude-sonnet-4-20250514",
"claude-opus-4-20250514",
"claude-3-7-sonnet-20250219",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
},
},
{
Name: "Google",
EnvVar: "GOOGLE_API_KEY",
Models: []string{
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-1.5-pro",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-1.0-pro",
},
},
}
// Check for API keys in environment
apiKeys := make(map[string]string)
for i := range providers {
if key := os.Getenv(providers[i].EnvVar); key != "" {
apiKeys[providers[i].Name] = key
providers[i].HasAPIKey = true
}
}
m := Model{
state: StateProviderSelect,
input: ti,
apiKeyInput: aki,
apiKeys: apiKeys,
providers: providers,
systemPrompt: "You are a helpful assistant.",
toolbox: createDemoToolbox(),
toolsEnabled: false,
messages: []DisplayMessage{},
conversation: []llm.Input{},
}
// Build list items for provider selection
m.listItems = make([]string, len(providers))
for i, p := range providers {
status := " (no key)"
if p.HasAPIKey {
status = " (ready)"
}
m.listItems[i] = p.Name + status
}
return m
}
// Init initializes the model
func (m Model) Init() tea.Cmd {
return textinput.Blink
}
// selectProvider sets up the selected provider
func (m *Model) selectProvider(index int) error {
if index < 0 || index >= len(m.providers) {
return nil
}
p := m.providers[index]
key, ok := m.apiKeys[p.Name]
if !ok || key == "" {
return nil
}
m.providerName = p.Name
m.providerIndex = index
switch p.Name {
case "OpenAI":
m.provider = llm.OpenAI(key)
case "Anthropic":
m.provider = llm.Anthropic(key)
case "Google":
m.provider = llm.Google(key)
}
// Select default model
if len(p.Models) > 0 {
return m.selectModel(p.ModelIndex)
}
return nil
}
// selectModel sets the current model
func (m *Model) selectModel(index int) error {
if m.provider == nil {
return nil
}
p := m.providers[m.providerIndex]
if index < 0 || index >= len(p.Models) {
return nil
}
modelName := p.Models[index]
chat, err := m.provider.ModelVersion(modelName)
if err != nil {
return err
}
m.chat = chat
m.modelName = modelName
m.providers[m.providerIndex].ModelIndex = index
return nil
}
// newConversation resets the conversation
func (m *Model) newConversation() {
m.conversation = []llm.Input{}
m.messages = []DisplayMessage{}
m.pendingImages = []llm.Image{}
m.err = nil
}
// addUserMessage adds a user message to the conversation
func (m *Model) addUserMessage(text string, images []llm.Image) {
msg := llm.Message{
Role: llm.RoleUser,
Text: text,
Images: images,
}
m.conversation = append(m.conversation, msg)
m.messages = append(m.messages, DisplayMessage{
Role: llm.RoleUser,
Content: text,
Images: len(images),
})
}
// addAssistantMessage adds an assistant message to the conversation
func (m *Model) addAssistantMessage(content string) {
m.messages = append(m.messages, DisplayMessage{
Role: llm.RoleAssistant,
Content: content,
})
}
// addToolCallMessage adds a tool call message to display
func (m *Model) addToolCallMessage(name string, args string) {
m.messages = append(m.messages, DisplayMessage{
Role: llm.Role("tool_call"),
Content: name + ": " + args,
})
}
// addToolResultMessage adds a tool result message to display
func (m *Model) addToolResultMessage(name string, result string) {
m.messages = append(m.messages, DisplayMessage{
Role: llm.Role("tool_result"),
Content: name + " -> " + result,
})
}

cmd/llm/styles.go (Normal file, +113 lines)

@@ -0,0 +1,113 @@
package main
import (
"github.com/charmbracelet/lipgloss"
)
var (
// Colors
primaryColor = lipgloss.Color("205")
secondaryColor = lipgloss.Color("39")
accentColor = lipgloss.Color("212")
mutedColor = lipgloss.Color("241")
errorColor = lipgloss.Color("196")
successColor = lipgloss.Color("82")
// App styles
appStyle = lipgloss.NewStyle().Padding(1, 2)
// Header
headerStyle = lipgloss.NewStyle().
Bold(true).
Foreground(primaryColor).
BorderStyle(lipgloss.NormalBorder()).
BorderBottom(true).
BorderForeground(mutedColor).
Padding(0, 1)
// Provider badge
providerBadgeStyle = lipgloss.NewStyle().
Background(secondaryColor).
Foreground(lipgloss.Color("0")).
Padding(0, 1).
Bold(true)
// Messages
systemMsgStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Italic(true).
Padding(0, 1)
userMsgStyle = lipgloss.NewStyle().
Foreground(secondaryColor).
Padding(0, 1)
assistantMsgStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255")).
Padding(0, 1)
roleLabelStyle = lipgloss.NewStyle().
Bold(true).
Width(12)
// Tool calls
toolCallStyle = lipgloss.NewStyle().
Foreground(accentColor).
Italic(true).
Padding(0, 1)
toolResultStyle = lipgloss.NewStyle().
Foreground(successColor).
Padding(0, 1)
// Input area
inputStyle = lipgloss.NewStyle().
BorderStyle(lipgloss.RoundedBorder()).
BorderForeground(primaryColor).
Padding(0, 1)
inputHelpStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Italic(true)
// Error
errorStyle = lipgloss.NewStyle().
Foreground(errorColor).
Bold(true)
// Loading
loadingStyle = lipgloss.NewStyle().
Foreground(accentColor).
Italic(true)
// List selection
selectedItemStyle = lipgloss.NewStyle().
Foreground(primaryColor).
Bold(true)
normalItemStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255"))
// Settings panel
settingLabelStyle = lipgloss.NewStyle().
Foreground(secondaryColor).
Width(15)
settingValueStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255"))
// Help text
helpStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Padding(1, 0)
// Image indicator
imageIndicatorStyle = lipgloss.NewStyle().
Foreground(accentColor).
Bold(true)
// Viewport
viewportStyle = lipgloss.NewStyle().
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(mutedColor)
)

cmd/llm/tools.go (Normal file, +105 lines)

@@ -0,0 +1,105 @@
package main

import (
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)

// TimeParams is the parameter struct for the GetTime function
type TimeParams struct{}

// GetTime returns the current time
func GetTime(_ *llm.Context, _ TimeParams) (any, error) {
	return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil
}

// CalcParams is the parameter struct for the Calculate function
type CalcParams struct {
	A  float64 `json:"a" description:"First number"`
	B  float64 `json:"b" description:"Second number"`
	Op string  `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"`
}

// Calculate performs basic math operations
func Calculate(_ *llm.Context, params CalcParams) (any, error) {
	switch strings.ToLower(params.Op) {
	case "add", "+":
		return params.A + params.B, nil
	case "subtract", "sub", "-":
		return params.A - params.B, nil
	case "multiply", "mul", "*":
		return params.A * params.B, nil
	case "divide", "div", "/":
		if params.B == 0 {
			return nil, fmt.Errorf("division by zero")
		}
		return params.A / params.B, nil
	case "power", "pow", "^":
		return math.Pow(params.A, params.B), nil
	case "sqrt":
		if params.A < 0 {
			return nil, fmt.Errorf("cannot take square root of negative number")
		}
		return math.Sqrt(params.A), nil
	case "mod", "%":
		return math.Mod(params.A, params.B), nil
	default:
		return nil, fmt.Errorf("unknown operation: %s", params.Op)
	}
}

// WeatherParams is the parameter struct for the GetWeather function
type WeatherParams struct {
	Location string `json:"location" description:"City name or location"`
}

// GetWeather returns mock weather data (for demo purposes)
func GetWeather(_ *llm.Context, params WeatherParams) (any, error) {
	// This is a demo function - returns mock data
	weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"}
	temps := []int{65, 72, 58, 80, 45}

	// Use location string to deterministically pick weather
	idx := len(params.Location) % len(weathers)
	return map[string]any{
		"location":    params.Location,
		"temperature": strconv.Itoa(temps[idx]) + "F",
		"condition":   weathers[idx],
		"humidity":    "45%",
		"note":        "This is mock data for demonstration purposes",
	}, nil
}

// RandomNumberParams is the parameter struct for the RandomNumber function
type RandomNumberParams struct {
	Min int `json:"min" description:"Minimum value (inclusive)"`
	Max int `json:"max" description:"Maximum value (inclusive)"`
}

// RandomNumber generates a pseudo-random number (using current time nanoseconds)
func RandomNumber(_ *llm.Context, params RandomNumberParams) (any, error) {
	if params.Min > params.Max {
		return nil, fmt.Errorf("min cannot be greater than max")
	}

	// Simple pseudo-random using time
	n := time.Now().UnixNano()
	rangeSize := params.Max - params.Min + 1
	result := params.Min + int(n%int64(rangeSize))
	return result, nil
}

// createDemoToolbox creates a toolbox with demo tools for testing
func createDemoToolbox() llm.ToolBox {
	return llm.NewToolBox(
		llm.NewFunction("get_time", "Get the current date and time", GetTime),
		llm.NewFunction("calculate", "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", Calculate),
		llm.NewFunction("get_weather", "Get weather information for a location (demo data)", GetWeather),
		llm.NewFunction("random_number", "Generate a random number between min and max", RandomNumber),
	)
}

cmd/llm/update.go (Normal file, +435 lines)

@@ -0,0 +1,435 @@
package main
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// pendingRequest stores the request being processed for follow-up
var pendingRequest llm.Request
var pendingResponse llm.ResponseChoice
// Update handles messages and updates the model
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
return m.handleKeyMsg(msg)
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
headerHeight := 3
footerHeight := 4
verticalMargins := headerHeight + footerHeight
if !m.viewportReady {
m.viewport = viewport.New(msg.Width-4, msg.Height-verticalMargins)
m.viewport.HighPerformanceRendering = false
m.viewportReady = true
} else {
m.viewport.Width = msg.Width - 4
m.viewport.Height = msg.Height - verticalMargins
}
m.input.Width = msg.Width - 6
m.apiKeyInput.Width = msg.Width - 6
m.viewport.SetContent(m.renderMessages())
case ChatResponseMsg:
m.loading = false
if msg.Err != nil {
m.err = msg.Err
return m, nil
}
if len(msg.Response.Choices) == 0 {
m.err = fmt.Errorf("no response choices returned")
return m, nil
}
choice := msg.Response.Choices[0]
// Check for tool calls
if len(choice.Calls) > 0 && m.toolsEnabled {
// Store for follow-up
pendingResponse = choice
// Add assistant's response to conversation if there's content
if choice.Content != "" {
m.addAssistantMessage(choice.Content)
}
// Display tool calls
for _, call := range choice.Calls {
m.addToolCallMessage(call.FunctionCall.Name, call.FunctionCall.Arguments)
}
m.viewport.SetContent(m.renderMessages())
m.viewport.GotoBottom()
// Execute tools
m.loading = true
return m, executeTools(m.toolbox, pendingRequest, choice)
}
// Regular response - add to conversation and display
m.conversation = append(m.conversation, choice)
m.addAssistantMessage(choice.Content)
m.viewport.SetContent(m.renderMessages())
m.viewport.GotoBottom()
case ToolExecutionMsg:
if msg.Err != nil {
m.loading = false
m.err = msg.Err
return m, nil
}
// Display tool results
for i, result := range msg.Results {
name := pendingResponse.Calls[i].FunctionCall.Name
resultStr := fmt.Sprintf("%v", result.Result)
if result.Error != nil {
resultStr = "Error: " + result.Error.Error()
}
m.addToolResultMessage(name, resultStr)
}
// Add tool call responses to conversation
for _, result := range msg.Results {
m.conversation = append(m.conversation, result)
}
// Add the assistant's response to conversation
m.conversation = append(m.conversation, pendingResponse)
m.viewport.SetContent(m.renderMessages())
m.viewport.GotoBottom()
// Send follow-up request
followUp := buildFollowUpRequest(&m, pendingRequest, pendingResponse, msg.Results)
pendingRequest = followUp
return m, sendChatRequest(m.chat, followUp)
case ImageLoadedMsg:
if msg.Err != nil {
m.err = msg.Err
m.state = m.previousState
return m, nil
}
m.pendingImages = append(m.pendingImages, msg.Image)
m.state = m.previousState
m.err = nil
default:
// Update text input
if m.state == StateChat {
m.input, cmd = m.input.Update(msg)
cmds = append(cmds, cmd)
} else if m.state == StateAPIKeyInput {
m.apiKeyInput, cmd = m.apiKeyInput.Update(msg)
cmds = append(cmds, cmd)
}
}
return m, tea.Batch(cmds...)
}
// handleKeyMsg handles keyboard input
func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
// Global key handling
switch msg.String() {
case "ctrl+c":
return m, tea.Quit
case "esc":
if m.state != StateChat {
m.state = StateChat
m.input.Focus()
return m, nil
}
return m, tea.Quit
}
// State-specific key handling
switch m.state {
case StateChat:
return m.handleChatKeys(msg)
case StateProviderSelect:
return m.handleProviderSelectKeys(msg)
case StateModelSelect:
return m.handleModelSelectKeys(msg)
case StateImageInput:
return m.handleImageInputKeys(msg)
case StateToolsPanel:
return m.handleToolsPanelKeys(msg)
case StateSettings:
return m.handleSettingsKeys(msg)
case StateAPIKeyInput:
return m.handleAPIKeyInputKeys(msg)
}
return m, nil
}
// handleChatKeys handles keys in chat state
func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "enter":
if m.loading {
return m, nil
}
text := strings.TrimSpace(m.input.Value())
if text == "" {
return m, nil
}
if m.chat == nil {
m.err = fmt.Errorf("no model selected - press Ctrl+P to select a provider")
return m, nil
}
// Build and send request
req := buildRequest(&m, text)
pendingRequest = req
// Add user message to display
m.addUserMessage(text, m.pendingImages)
// Clear input and pending images
m.input.Reset()
m.pendingImages = nil
m.err = nil
m.loading = true
m.viewport.SetContent(m.renderMessages())
m.viewport.GotoBottom()
return m, sendChatRequest(m.chat, req)
case "ctrl+i":
m.previousState = StateChat
m.state = StateImageInput
m.input.SetValue("")
m.input.Placeholder = "Enter image path or URL..."
return m, nil
case "ctrl+t":
m.state = StateToolsPanel
return m, nil
case "ctrl+p":
m.state = StateProviderSelect
m.listIndex = m.providerIndex
return m, nil
case "ctrl+m":
if m.provider == nil {
m.err = fmt.Errorf("select a provider first")
return m, nil
}
m.state = StateModelSelect
m.listItems = m.providers[m.providerIndex].Models
m.listIndex = m.providers[m.providerIndex].ModelIndex
return m, nil
case "ctrl+s":
m.state = StateSettings
return m, nil
case "ctrl+n":
m.newConversation()
m.viewport.SetContent(m.renderMessages())
return m, nil
case "up", "down", "pgup", "pgdown":
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
default:
var cmd tea.Cmd
m.input, cmd = m.input.Update(msg)
return m, cmd
}
}
// handleProviderSelectKeys handles keys in provider selection state
func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "up", "k":
if m.listIndex > 0 {
m.listIndex--
}
case "down", "j":
if m.listIndex < len(m.providers)-1 {
m.listIndex++
}
case "enter":
p := m.providers[m.listIndex]
if !p.HasAPIKey {
// Need to get API key
m.state = StateAPIKeyInput
m.apiKeyInput.Focus()
m.apiKeyInput.SetValue("")
return m, textinput.Blink
}
err := m.selectProvider(m.listIndex)
if err != nil {
m.err = err
return m, nil
}
m.state = StateChat
m.input.Focus()
m.newConversation()
return m, nil
}
return m, nil
}
// handleAPIKeyInputKeys handles keys in API key input state
func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "enter":
key := strings.TrimSpace(m.apiKeyInput.Value())
if key == "" {
return m, nil
}
// Store the API key
p := m.providers[m.listIndex]
m.apiKeys[p.Name] = key
m.providers[m.listIndex].HasAPIKey = true
// Update list items
for i, prov := range m.providers {
status := " (no key)"
if prov.HasAPIKey {
status = " (ready)"
}
m.listItems[i] = prov.Name + status
}
// Select the provider
err := m.selectProvider(m.listIndex)
if err != nil {
m.err = err
return m, nil
}
m.state = StateChat
m.input.Focus()
m.newConversation()
return m, nil
default:
var cmd tea.Cmd
m.apiKeyInput, cmd = m.apiKeyInput.Update(msg)
return m, cmd
}
}
// handleModelSelectKeys handles keys in model selection state
func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "up", "k":
if m.listIndex > 0 {
m.listIndex--
}
case "down", "j":
if m.listIndex < len(m.listItems)-1 {
m.listIndex++
}
case "enter":
err := m.selectModel(m.listIndex)
if err != nil {
m.err = err
return m, nil
}
m.state = StateChat
m.input.Focus()
}
return m, nil
}
// handleImageInputKeys handles keys in image input state
func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "enter":
input := strings.TrimSpace(m.input.Value())
if input == "" {
m.state = m.previousState
m.input.Placeholder = "Type your message..."
return m, nil
}
m.input.Placeholder = "Type your message..."
// Determine input type and load
if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
return m, loadImageFromURL(input)
} else if strings.HasPrefix(input, "data:") || (len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\")) {
return m, loadImageFromBase64(input)
} else {
return m, loadImageFromPath(input)
}
default:
var cmd tea.Cmd
m.input, cmd = m.input.Update(msg)
return m, cmd
}
}
// handleToolsPanelKeys handles keys in tools panel state
func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "t":
m.toolsEnabled = !m.toolsEnabled
case "enter", "q":
m.state = StateChat
m.input.Focus()
}
return m, nil
}
// handleSettingsKeys handles keys in settings state
func (m Model) handleSettingsKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "1":
// Set temperature to nil (default)
m.temperature = nil
case "2":
t := 0.0
m.temperature = &t
case "3":
t := 0.5
m.temperature = &t
case "4":
t := 0.7
m.temperature = &t
case "5":
t := 1.0
m.temperature = &t
case "enter", "q":
m.state = StateChat
m.input.Focus()
}
return m, nil
}

cmd/llm/view.go (new file)
@@ -0,0 +1,296 @@
package main
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// View renders the current state
func (m Model) View() string {
switch m.state {
case StateProviderSelect:
return m.renderProviderSelect()
case StateModelSelect:
return m.renderModelSelect()
case StateImageInput:
return m.renderImageInput()
case StateToolsPanel:
return m.renderToolsPanel()
case StateSettings:
return m.renderSettings()
case StateAPIKeyInput:
return m.renderAPIKeyInput()
default:
return m.renderChat()
}
}
// renderChat renders the main chat view
func (m Model) renderChat() string {
var b strings.Builder
// Header
provider := m.providerName
if provider == "" {
provider = "None"
}
model := m.modelName
if model == "" {
model = "None"
}
header := headerStyle.Render(fmt.Sprintf("go-llm CLI %s",
providerBadgeStyle.Render(fmt.Sprintf("%s/%s", provider, model))))
b.WriteString(header)
b.WriteString("\n")
// Messages viewport
if m.viewportReady {
b.WriteString(m.viewport.View())
b.WriteString("\n")
}
// Image indicator
if len(m.pendingImages) > 0 {
b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s) attached]", len(m.pendingImages))))
b.WriteString("\n")
}
// Error
if m.err != nil {
b.WriteString(errorStyle.Render(" Error: " + m.err.Error()))
b.WriteString("\n")
}
// Loading
if m.loading {
b.WriteString(loadingStyle.Render(" Thinking..."))
b.WriteString("\n")
}
// Input
inputBox := inputStyle.Render(m.input.View())
b.WriteString(inputBox)
b.WriteString("\n")
// Help
help := inputHelpStyle.Render("Enter: send | Ctrl+I: image | Ctrl+T: tools | Ctrl+P: provider | Ctrl+M: model | Ctrl+S: settings | Ctrl+N: new | Esc: quit")
b.WriteString(help)
return appStyle.Render(b.String())
}
// renderMessages renders all messages for the viewport
func (m Model) renderMessages() string {
var b strings.Builder
if len(m.messages) == 0 {
b.WriteString(systemMsgStyle.Render("[System] " + m.systemPrompt))
b.WriteString("\n\n")
b.WriteString(lipgloss.NewStyle().Foreground(mutedColor).Render("Start a conversation by typing a message below."))
return b.String()
}
b.WriteString(systemMsgStyle.Render("[System] " + m.systemPrompt))
b.WriteString("\n\n")
for _, msg := range m.messages {
var content string
var style lipgloss.Style
switch msg.Role {
case llm.RoleUser:
style = userMsgStyle
label := roleLabelStyle.Foreground(secondaryColor).Render("[User]")
content = label + " " + msg.Content
if msg.Images > 0 {
content += imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s)]", msg.Images))
}
case llm.RoleAssistant:
style = assistantMsgStyle
label := roleLabelStyle.Foreground(lipgloss.Color("255")).Render("[Assistant]")
content = label + " " + msg.Content
case llm.Role("tool_call"):
style = toolCallStyle
content = " -> Calling: " + msg.Content
case llm.Role("tool_result"):
style = toolResultStyle
content = " <- Result: " + msg.Content
default:
style = assistantMsgStyle
content = msg.Content
}
b.WriteString(style.Render(content))
b.WriteString("\n\n")
}
return b.String()
}
// renderProviderSelect renders the provider selection view
func (m Model) renderProviderSelect() string {
var b strings.Builder
b.WriteString(headerStyle.Render("Select Provider"))
b.WriteString("\n\n")
for i, item := range m.listItems {
cursor := " "
style := normalItemStyle
if i == m.listIndex {
cursor = "> "
style = selectedItemStyle
}
b.WriteString(style.Render(cursor + item))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Use arrow keys or j/k to navigate, Enter to select, Esc to cancel"))
return appStyle.Render(b.String())
}
// renderAPIKeyInput renders the API key input view
func (m Model) renderAPIKeyInput() string {
var b strings.Builder
provider := m.providers[m.listIndex]
b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Name)))
b.WriteString("\n\n")
b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.EnvVar))
b.WriteString("Enter your API key below (it will be hidden):\n\n")
inputBox := inputStyle.Render(m.apiKeyInput.View())
b.WriteString(inputBox)
b.WriteString("\n\n")
b.WriteString(helpStyle.Render("Enter to confirm, Esc to cancel"))
return appStyle.Render(b.String())
}
// renderModelSelect renders the model selection view
func (m Model) renderModelSelect() string {
var b strings.Builder
b.WriteString(headerStyle.Render(fmt.Sprintf("Select Model (%s)", m.providerName)))
b.WriteString("\n\n")
for i, item := range m.listItems {
cursor := " "
style := normalItemStyle
if i == m.listIndex {
cursor = "> "
style = selectedItemStyle
}
if item == m.modelName {
item += " (current)"
}
b.WriteString(style.Render(cursor + item))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Use arrow keys or j/k to navigate, Enter to select, Esc to cancel"))
return appStyle.Render(b.String())
}
// renderImageInput renders the image input view
func (m Model) renderImageInput() string {
var b strings.Builder
b.WriteString(headerStyle.Render("Add Image"))
b.WriteString("\n\n")
b.WriteString("Enter an image source:\n")
b.WriteString(" - File path (e.g., /path/to/image.png)\n")
b.WriteString(" - URL (e.g., https://example.com/image.jpg)\n")
b.WriteString(" - Base64 data or data URL\n\n")
if len(m.pendingImages) > 0 {
b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf("Currently attached: %d image(s)\n\n", len(m.pendingImages))))
}
inputBox := inputStyle.Render(m.input.View())
b.WriteString(inputBox)
b.WriteString("\n\n")
b.WriteString(helpStyle.Render("Enter to add image, Esc to cancel"))
return appStyle.Render(b.String())
}
// renderToolsPanel renders the tools panel
func (m Model) renderToolsPanel() string {
var b strings.Builder
b.WriteString(headerStyle.Render("Tools / Function Calling"))
b.WriteString("\n\n")
status := "DISABLED"
statusStyle := errorStyle
if m.toolsEnabled {
status = "ENABLED"
statusStyle = lipgloss.NewStyle().Foreground(successColor).Bold(true)
}
b.WriteString(settingLabelStyle.Render("Tools Status:"))
b.WriteString(statusStyle.Render(status))
b.WriteString("\n\n")
b.WriteString("Available tools:\n")
for _, fn := range m.toolbox.Functions() {
b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(fn.Name), fn.Description))
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Press 't' to toggle tools, Enter or 'q' to close"))
return appStyle.Render(b.String())
}
// renderSettings renders the settings view
func (m Model) renderSettings() string {
var b strings.Builder
b.WriteString(headerStyle.Render("Settings"))
b.WriteString("\n\n")
// Temperature
tempStr := "default"
if m.temperature != nil {
tempStr = fmt.Sprintf("%.1f", *m.temperature)
}
b.WriteString(settingLabelStyle.Render("Temperature:"))
b.WriteString(settingValueStyle.Render(tempStr))
b.WriteString("\n\n")
b.WriteString("Press a key to set temperature:\n")
b.WriteString(" 1 - Default (model decides)\n")
b.WriteString(" 2 - 0.0 (deterministic)\n")
b.WriteString(" 3 - 0.5 (balanced)\n")
b.WriteString(" 4 - 0.7 (creative)\n")
b.WriteString(" 5 - 1.0 (very creative)\n")
b.WriteString("\n")
// System prompt
b.WriteString(settingLabelStyle.Render("System Prompt:"))
b.WriteString("\n")
b.WriteString(settingValueStyle.Render(" " + m.systemPrompt))
b.WriteString("\n\n")
b.WriteString(helpStyle.Render("Enter or 'q' to close"))
return appStyle.Render(b.String())
}

@@ -1,4 +1,4 @@
-package go_llm
+package llm
import (
"context"

docs/sandbox-setup.md (new file)
@@ -0,0 +1,575 @@
# Sandbox Setup & Hardening Guide
Complete guide for setting up a Proxmox VE host to run isolated LXC sandbox containers for the go-llm sandbox package.
## Table of Contents
1. [Prerequisites](#1-prerequisites)
2. [Proxmox Host Preparation](#2-proxmox-host-preparation)
3. [Network Setup](#3-network-setup)
4. [LXC Template Creation](#4-lxc-template-creation)
5. [SSH Key Setup](#5-ssh-key-setup)
6. [Configuration](#6-configuration)
7. [Hardening Checklist](#7-hardening-checklist)
8. [Monitoring & Maintenance](#8-monitoring--maintenance)
9. [Troubleshooting](#9-troubleshooting)
---
## 1. Prerequisites
### Hardware/VM Requirements
| Resource | Minimum | Recommended |
|----------|---------|-------------|
| CPU | 4 cores | 8+ cores |
| RAM | 8 GB | 16+ GB |
| Storage | 100 GB SSD | 250+ GB SSD |
| Network | 1 NIC | 2 NICs (mgmt + sandbox) |
### Software
- Proxmox VE 8.x ([installation guide](https://pve.proxmox.com/wiki/Installation))
- During install, configure the management interface on `vmbr0`
---
## 2. Proxmox Host Preparation
### Create Resource Pool
Scope sandbox containers to a dedicated resource pool to limit API token access:
```bash
pvesh create /pools --poolid sandbox-pool
```
### Create API User and Token
```bash
# Create dedicated user
pveum useradd mort-sandbox@pve
# Create role with minimum required permissions
pveum roleadd SandboxAdmin -privs "VM.Allocate,VM.Clone,VM.Audit,VM.PowerMgmt,VM.Console,Datastore.AllocateSpace,Datastore.Audit"
# Grant role on the sandbox pool only
pveum aclmod /pool/sandbox-pool -user mort-sandbox@pve -role SandboxAdmin
# Grant access to the template storage
pveum aclmod /storage/local -user mort-sandbox@pve -role PVEDatastoreUser
# Create API token (privsep=0 means token inherits user's permissions)
pveum user token add mort-sandbox@pve sandbox-token --privsep=0
```
Save the output — it contains the token secret:
```
┌──────────┬──────────────────────────────────────────┐
│ key │ value │
╞══════════╪══════════════════════════════════════════╡
│ full-tokenid │ mort-sandbox@pve!sandbox-token │
│ value │ xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx │
└──────────┴──────────────────────────────────────────┘
```
Store the secret securely (environment variable, secret manager, etc.). Never commit it to source control.
---
## 3. Network Setup
### 3.1 Create Isolated Bridge
Add to `/etc/network/interfaces` on the Proxmox host:
```
auto vmbr1
iface vmbr1 inet static
address 10.99.0.1/16
bridge-ports none
bridge-stp off
bridge-fd 0
post-up echo 1 > /proc/sys/net/ipv4/ip_forward
# NAT for optional internet access (controlled per-container by nftables)
post-up nft add table nat 2>/dev/null; true
post-up nft add chain nat postrouting { type nat hook postrouting priority 100 \; } 2>/dev/null; true
post-up nft add rule nat postrouting oifname "vmbr0" ip saddr 10.99.0.0/16 masquerade 2>/dev/null; true
```
Apply the configuration:
```bash
ifreload -a
```
### 3.2 Install and Configure DHCP
```bash
apt-get install -y dnsmasq
```
Create `/etc/dnsmasq.d/sandbox.conf`:
```
interface=vmbr1
bind-interfaces
dhcp-range=10.99.1.1,10.99.254.254,255.255.0.0,1h
dhcp-option=option:router,10.99.0.1
dhcp-option=option:dns-server,1.1.1.1,8.8.8.8
```
Restart dnsmasq:
```bash
systemctl restart dnsmasq
systemctl enable dnsmasq
```
### 3.3 Configure nftables Firewall
Create `/etc/nftables.conf`:
```nft
#!/usr/sbin/nft -f
flush ruleset
table inet sandbox {
# Dynamic set of container IPs allowed internet access.
# Populated/cleared by the sandbox manager via the Proxmox API.
set internet_allowed {
type ipv4_addr
}
chain forward {
type filter hook forward priority 0; policy drop;
# Allow established/related connections
ct state established,related accept
# Allow inter-bridge traffic (host ↔ containers via vmbr1)
iifname "vmbr1" oifname "vmbr1" accept
# Allow DNS for all containers (needed for apt)
ip saddr 10.99.0.0/16 udp dport 53 accept
ip saddr 10.99.0.0/16 tcp dport 53 accept
# Allow HTTP/HTTPS only for containers in the internet_allowed set
ip saddr @internet_allowed tcp dport { 80, 443 } accept
# Rate limit: drop new connections above 50/second from the sandbox subnet (aggregate, not per container)
ip saddr 10.99.0.0/16 ct state new limit rate over 50/second drop
# Block everything else from containers
ip saddr 10.99.0.0/16 drop
# Allow host → containers (for SSH from the application)
ip daddr 10.99.0.0/16 accept
}
chain input {
type filter hook input priority 0; policy accept;
# Block containers from accessing Proxmox management ports
# (only SSH is allowed for the sandbox manager)
iifname "vmbr1" ip daddr 10.99.0.1 tcp dport != 22 drop
}
}
# NAT table for optional internet access
table nat {
chain postrouting {
type nat hook postrouting priority 100;
oifname "vmbr0" ip saddr 10.99.0.0/16 masquerade
}
}
```
Apply and persist:
```bash
nft -f /etc/nftables.conf
systemctl enable nftables
```
Verify:
```bash
nft list ruleset
```
### 3.4 Test Network Isolation
From a test container on `vmbr1`:
```bash
# Should work: DNS resolution
dig google.com
# Should be blocked: HTTPS (container is not in the internet_allowed set)
curl -s --connect-timeout 5 https://google.com && echo "FAIL: should be blocked" || echo "OK: blocked"
# Should be blocked: access to LAN
ping -c 1 -W 2 192.168.1.1 && echo "FAIL: LAN reachable" || echo "OK: LAN blocked"
# Should be blocked: access to Proxmox management
curl -s --connect-timeout 5 https://10.99.0.1:8006 && echo "FAIL: Proxmox reachable" || echo "OK: Proxmox blocked"
```
---
## 4. LXC Template Creation
### 4.1 Download Base Image
```bash
pveam update
pveam download local ubuntu-24.04-standard_24.04-1_amd64.tar.zst
```
### 4.2 Create Template Container
```bash
pct create 9000 local:vztmpl/ubuntu-24.04-standard_24.04-1_amd64.tar.zst \
--hostname sandbox-template \
--memory 1024 \
--swap 0 \
--cores 1 \
--rootfs local-lvm:8 \
--net0 name=eth0,bridge=vmbr1,ip=dhcp \
--unprivileged 1 \
--features nesting=0 \
--ostype ubuntu \
--ssh-public-keys /root/.ssh/mort_sandbox.pub \
--pool sandbox-pool \
--start 0
```
### 4.3 Install Base Packages
```bash
pct start 9000
pct exec 9000 -- bash -c '
apt-get update && apt-get install -y --no-install-recommends \
build-essential \
python3 python3-pip python3-venv \
nodejs npm \
git curl wget jq \
vim nano \
htop tree \
ca-certificates \
openssh-server \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
'
```
### 4.4 Create Sandbox User
```bash
pct exec 9000 -- bash -c '
# Create unprivileged sandbox user with sudo
useradd -m -s /bin/bash -G sudo sandbox
echo "sandbox ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/sandbox
# Set up SSH access
mkdir -p /home/sandbox/.ssh
cp /root/.ssh/authorized_keys /home/sandbox/.ssh/
chown -R sandbox:sandbox /home/sandbox/.ssh
chmod 700 /home/sandbox/.ssh
chmod 600 /home/sandbox/.ssh/authorized_keys
# Create uploads directory
mkdir -p /home/sandbox/uploads
chown sandbox:sandbox /home/sandbox/uploads
# Enable SSH
systemctl enable ssh
'
```
### 4.5 Security Hardening
```bash
pct exec 9000 -- bash -c '
# Process limits (prevent fork bombs)
echo "* soft nproc 256" >> /etc/security/limits.conf
echo "* hard nproc 512" >> /etc/security/limits.conf
# Disable core dumps
echo "* hard core 0" >> /etc/security/limits.conf
# Disable unnecessary services
systemctl disable systemd-resolved 2>/dev/null || true
systemctl disable snapd 2>/dev/null || true
'
```
### 4.6 Convert to Template
```bash
pct stop 9000
pct template 9000
```
### 4.7 Verify Template
Clone and test manually:
```bash
pct clone 9000 9999 --hostname test-sandbox --full
pct start 9999
# Wait for DHCP, then SSH in
ssh sandbox@<container-ip>
# Run some commands, verify packages installed
sudo apt-get update
python3 --version
node --version
# Clean up
exit
pct stop 9999
pct destroy 9999
```
---
## 5. SSH Key Setup
### 5.1 Generate Key Pair
```bash
ssh-keygen -t ed25519 -f /etc/mort/sandbox_key -N "" -C "mort-sandbox"
```
### 5.2 Install Public Key in Template
This was done in step 4.2 with `--ssh-public-keys`. If you need to update it:
```bash
pct start 9000 # works only while 9000 is still a regular container; templates cannot be started (see "Template Updates" in section 8)
# Copy key
cat /etc/mort/sandbox_key.pub | pct exec 9000 -- tee /home/sandbox/.ssh/authorized_keys
pct exec 9000 -- chown sandbox:sandbox /home/sandbox/.ssh/authorized_keys
pct exec 9000 -- chmod 600 /home/sandbox/.ssh/authorized_keys
pct stop 9000
pct template 9000
```
### 5.3 Set Permissions
```bash
chmod 600 /etc/mort/sandbox_key
chmod 644 /etc/mort/sandbox_key.pub
# If running as a specific user:
chown mort:mort /etc/mort/sandbox_key /etc/mort/sandbox_key.pub
```
---
## 6. Configuration
### Go Configuration
```go
signer, _ := sandbox.LoadSSHKey("/etc/mort/sandbox_key")
mgr, _ := sandbox.NewManager(sandbox.Config{
Proxmox: sandbox.ProxmoxConfig{
BaseURL: "https://proxmox.local:8006",
TokenID: "mort-sandbox@pve!sandbox-token",
Secret: os.Getenv("SANDBOX_PROXMOX_SECRET"),
Node: "pve",
TemplateID: 9000,
Pool: "sandbox-pool",
Bridge: "vmbr1",
InsecureSkipVerify: true, // Only for self-signed certs
},
SSH: sandbox.SSHConfig{
Signer: signer,
User: "sandbox", // default
ConnectTimeout: 10 * time.Second, // default
CommandTimeout: 60 * time.Second, // default
},
Defaults: sandbox.ContainerConfig{
CPUs: 1,
MemoryMB: 1024,
DiskGB: 8,
},
})
```
### Environment Variables
| Variable | Description |
|----------|-------------|
| `SANDBOX_PROXMOX_SECRET` | Proxmox API token secret |
| `SANDBOX_SSH_KEY_PATH` | Path to SSH private key (alternative to config) |
---
## 7. Hardening Checklist
Run through this checklist after setup:
### Container Isolation
- [ ] Containers are unprivileged (verify UID mapping in `/etc/pve/lxc/<id>.conf`)
- [ ] Nesting is disabled (`features: nesting=0`)
- [ ] Swap is disabled on containers (`swap: 0`)
- [ ] Resource pool scoping: API token can only touch `sandbox-pool`
### Network Isolation
- [ ] `vmbr1` has no physical ports (`bridge-ports none`)
- [ ] nftables rules loaded: `nft list ruleset` shows sandbox table
- [ ] nftables persists across reboots: `systemctl is-enabled nftables`
- [ ] Default-deny outbound for containers
- [ ] DNS (port 53) allowed for all containers
- [ ] HTTP/HTTPS only for containers in `internet_allowed` set
- [ ] Rate limiting active (50 conn/sec)
- [ ] Containers cannot reach Proxmox management (port 8006 blocked)
### Security Profiles
- [ ] AppArmor profile active: `lxc-container-default-cgns`
- [ ] Process limits in `/etc/security/limits.conf` (nproc 256/512)
- [ ] Core dumps disabled
- [ ] Capability drops verified in container config
### Functional Tests
- [ ] **Fork bomb test**: run `:(){ :|:& };:` in container → PID limit fires, container survives
- [ ] **OOM test**: allocate >1GB memory → container OOM-killed, host unaffected
- [ ] **Network scan test**: `nmap` from container → blocked by nftables
- [ ] **Container escape test**: attempt to mount host filesystem → denied
- [ ] **LAN access test**: ping LAN hosts → blocked
- [ ] **Cross-container test**: ping other sandbox containers → blocked
- [ ] **Internet access test**: HTTP without being in `internet_allowed` → blocked
- [ ] **Internet access test**: add to `internet_allowed` → HTTP works
- [ ] **Cleanup test**: destroy container → verify no orphan volumes
### Operational Tests
- [ ] Clone template → container starts → SSH connects → commands work
- [ ] File upload/download via SFTP works
- [ ] Container destroy removes all resources
- [ ] Orphan cleanup: kill application mid-session, restart, verify cleanup
---
## 8. Monitoring & Maintenance
### Log Rotation
Sandbox session logs should be rotated to prevent disk exhaustion. If using slog to a file:
```
# /etc/logrotate.d/sandbox
/var/log/sandbox/*.log {
daily
rotate 14
compress
delaycompress
missingok
notifempty
}
```
### Storage Cleanup
Verify destroyed containers don't leave orphan volumes:
```bash
# List all LVM volumes in the sandbox storage
lvs | grep sandbox
# Compare with running containers
pct list | grep sandbox
```
### Template Updates
Periodically update the template with latest packages:
```bash
# There is no direct "un-template" operation; clone the template,
# update the clone, then promote the clone to be the new template:
pct clone 9000 9001 --hostname template-update --full
pct start 9001
pct exec 9001 -- bash -c 'apt-get update && apt-get upgrade -y && apt-get clean'
pct stop 9001
# Destroy old template and create new one
pct destroy 9000
# Rename 9001 → 9000 (or update your config to use the new ID)
pct template 9001
```
### Proxmox Host Updates
```bash
apt-get update && apt-get dist-upgrade -y
# Reboot if kernel was updated
# Verify nftables rules are still loaded after reboot
nft list ruleset
```
---
## 9. Troubleshooting
### Container won't start
```bash
# Check task log
pct start <id>
# If error, check:
journalctl -u pve-container@<id> -n 50
# Common issues:
# - Storage full: check `df -h` and `lvs`
# - UID mapping issues: verify /etc/subuid and /etc/subgid
```
### SSH connection refused
```bash
# Verify container is running
pct status <id>
# Check if SSH is running inside container
pct exec <id> -- systemctl status ssh
# Verify IP assignment
pct exec <id> -- ip addr show eth0
# Check DHCP leases
cat /var/lib/misc/dnsmasq.leases
```
### Container has no internet (when it should)
```bash
# Verify container IP is in the internet_allowed set
nft list set inet sandbox internet_allowed
# Manually add for testing
nft add element inet sandbox internet_allowed { 10.99.1.5 }
# Verify NAT is working
nft list table nat
# Check if IP forwarding is enabled
cat /proc/sys/net/ipv4/ip_forward # Should be 1
```
### nftables rules lost after reboot
```bash
# Verify nftables is enabled
systemctl is-enabled nftables
# If rules are missing, reload
nft -f /etc/nftables.conf
# Make sure the config file is correct
nft -c -f /etc/nftables.conf # Check syntax without applying
```
### Orphaned containers
```bash
# List all containers in the sandbox pool
pvesh get /pools/sandbox-pool --output-format json | jq '.members[] | select(.type == "lxc")'
# Destroy orphans manually
pct stop <id> && pct destroy <id> --force --purge
```

@@ -1,4 +1,4 @@
-package go_llm
+package llm
import "fmt"

@@ -1,4 +1,4 @@
-package go_llm
+package llm
import (
"context"

@@ -1,4 +1,4 @@
-package go_llm
+package llm
import (
"reflect"

go.mod

@@ -1,48 +1,67 @@
module gitea.stevedudenhoeffer.com/steve/go-llm
go 1.23.1
go 1.24.0
toolchain go1.24.2
require (
github.com/google/generative-ai-go v0.19.0
github.com/liushuangls/go-anthropic/v2 v2.15.0
github.com/openai/openai-go v0.1.0-beta.9
golang.org/x/image v0.29.0
google.golang.org/api v0.228.0
github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/joho/godotenv v1.5.1
github.com/liushuangls/go-anthropic/v2 v2.17.0
github.com/modelcontextprotocol/go-sdk v1.2.0
github.com/openai/openai-go v1.12.0
golang.org/x/image v0.35.0
google.golang.org/genai v1.43.0
)
require (
cloud.google.com/go v0.120.0 // indirect
cloud.google.com/go/ai v0.10.1 // indirect
cloud.google.com/go/auth v0.15.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/longrunning v0.6.6 // indirect
cloud.google.com/go v0.123.0 // indirect
cloud.google.com/go/auth v0.18.1 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/net v0.39.0 // indirect
golang.org/x/oauth2 v0.29.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect
google.golang.org/grpc v1.71.1 // indirect
google.golang.org/protobuf v1.36.6 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
go.opentelemetry.io/otel v1.39.0 // indirect
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/net v0.49.0 // indirect
golang.org/x/oauth2 v0.32.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect
google.golang.org/grpc v1.78.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

go.sum

@@ -1,97 +1,145 @@
cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA=
cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q=
cloud.google.com/go/ai v0.10.1 h1:EU93KqYmMeOKgaBXAz2DshH2C/BzAT1P+iJORksLIic=
cloud.google.com/go/ai v0.10.1/go.mod h1:sWWHZvmJ83BjuxAQtYEiA0SFTpijtbH+SXWFO14ri5A=
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps=
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/longrunning v0.6.6 h1:XJNDo5MUfMM05xK3ewpbSdmt7R2Zw+aQEMbdQR65Rbw=
cloud.google.com/go/longrunning v0.6.6/go.mod h1:hyeGJUrPHcx0u2Uu1UFSoYZLn4lkMrccJig0t4FI7yw=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4=
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9SkNrF8ErYlRoQ=
github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4=
github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo=
github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao=
github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8=
github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y=
github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/image v0.29.0 h1:HcdsyR4Gsuys/Axh0rDEmlBmB68rW1U9BUdB3UVHsas=
golang.org/x/image v0.29.0/go.mod h1:RVJROnf3SLK8d26OW91j4FrIHGbsJ8QnbEocVTOWQDA=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs=
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI=
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM=
google.golang.org/genai v1.43.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d h1:xXzuihhT3gL/ntduUZwHECzAn57E8dA6l8SOtYWdD8Q=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

google.go

@@ -1,4 +1,4 @@
package go_llm
package llm
import (
"context"
@@ -8,26 +8,28 @@ import (
"io"
"net/http"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
"google.golang.org/genai"
)
type google struct {
type googleImpl struct {
key string
model string
}
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
var _ LLM = googleImpl{}
func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
g.model = modelVersion
return g, nil
}
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
res := *model
func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) {
var contents []*genai.Content
var cfg genai.GenerateContentConfig
for _, tool := range in.Toolbox.functions {
res.Tools = append(res.Tools, &genai.Tool{
for _, tool := range in.Toolbox.Functions() {
cfg.Tools = append(cfg.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{
{
Name: tool.Name,
@@ -38,48 +40,44 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (
})
}
if !in.Toolbox.RequiresTool() {
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingAny,
if in.Toolbox.RequiresTool() {
cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeAny,
}}
}
cs := res.StartChat()
for i, c := range in.Messages {
content := genai.NewUserContent(genai.Text(c.Text))
for _, c := range in.Messages {
var role genai.Role
switch c.Role {
case RoleAssistant, RoleSystem:
content.Role = "model"
role = genai.RoleModel
case RoleUser:
content.Role = "user"
role = genai.RoleUser
}
var parts []*genai.Part
if c.Text != "" {
parts = append(parts, genai.NewPartFromText(c.Text))
}
for _, img := range c.Images {
if img.Url != "" {
// gemini does not support URLs, so we need to download the image and convert it to a blob
// Download the image from the URL
resp, err := http.Get(img.Url)
if err != nil {
panic(fmt.Sprintf("error downloading image: %v", err))
}
defer resp.Body.Close()
// Check the Content-Length to ensure it's not over 20MB
if resp.ContentLength > 20*1024*1024 {
panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
}
// Read the content into a byte slice
data, err := io.ReadAll(resp.Body)
if err != nil {
panic(fmt.Sprintf("error reading image data: %v", err))
}
// Ensure the MIME type is appropriate
mimeType := http.DetectContentType(data)
switch mimeType {
case "image/jpeg", "image/png", "image/gif":
@@ -88,38 +86,24 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (
panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
}
// Create a genai.Blob using the validated image data
content.Parts = append(content.Parts, genai.Blob{
MIMEType: mimeType,
Data: data,
})
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
} else {
// convert base64 to blob
b, e := base64.StdEncoding.DecodeString(img.Base64)
if e != nil {
panic(fmt.Sprintf("error decoding base64: %v", e))
}
content.Parts = append(content.Parts, genai.Blob{
MIMEType: img.ContentType,
Data: b,
})
parts = append(parts, genai.NewPartFromBytes(b, img.ContentType))
}
}
// if this is the last message, we want to add to history, we want it to be the parts
if i == len(in.Messages)-1 {
return &res, cs, content.Parts
}
cs.History = append(cs.History, content)
contents = append(contents, genai.NewContentFromParts(parts, role))
}
return &res, cs, nil
return contents, &cfg
}
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
res := Response{}
for _, c := range in.Candidates {
@@ -127,15 +111,12 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
var set = false
if c.Content != nil {
for _, p := range c.Content.Parts {
switch p.(type) {
case genai.Text:
choice.Content = string(p.(genai.Text))
if p.Text != "" {
set = true
case genai.FunctionCall:
v := p.(genai.FunctionCall)
choice.Content = p.Text
} else if p.FunctionCall != nil {
v := p.FunctionCall
b, e := json.Marshal(v.Args)
if e != nil {
return Response{}, fmt.Errorf("error marshalling args: %w", e)
}
@@ -150,8 +131,6 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
choice.Calls = append(choice.Calls, call)
set = true
default:
return Response{}, fmt.Errorf("unknown part type: %T", p)
}
}
}
@@ -165,23 +144,19 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
return res, nil
}
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: g.key,
Backend: genai.BackendGeminiAPI,
})
if err != nil {
return Response{}, fmt.Errorf("error creating genai client: %w", err)
}
model := cl.GenerativeModel(g.model)
_, cs, parts := g.requestToChatHistory(req, model)
resp, err := cs.SendMessage(ctx, parts...)
//parts := g.requestToGoogleRequest(req, model)
//resp, err := model.GenerateContent(ctx, parts...)
contents, cfg := g.requestToContents(req)
resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg)
if err != nil {
return Response{}, fmt.Errorf("error generating content: %w", err)
}


@@ -1,4 +1,4 @@
package utils
package imageutil
import (
"bytes"
@@ -12,8 +12,8 @@ import (
"golang.org/x/image/draw"
)
// CompressImage takes a base64encoded image (JPEG, PNG or GIF) and returns
// a base64encoded version that is at most maxLength in size, or an error.
// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns
// a base-64-encoded version that is at most maxLength in size, or an error.
func CompressImage(b64 string, maxLength int) (string, string, error) {
raw, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
@@ -29,12 +29,12 @@ func CompressImage(b64 string, maxLength int) (string, string, error) {
case "image/gif":
return compressGIF(raw, maxLength)
default: // jpeg, png, webp, etc. treat as raster
default: // jpeg, png, webp, etc. -> treat as raster
return compressRaster(raw, maxLength)
}
}
// ---------- Raster path (jpeg / png / singleframe gif) ----------
// ---------- Raster path (jpeg / png / single-frame gif) ----------
func compressRaster(src []byte, maxLength int) (string, string, error) {
img, _, err := image.Decode(bytes.NewReader(src))
@@ -57,7 +57,7 @@ func compressRaster(src []byte, maxLength int) (string, string, error) {
continue
}
// downscale 80%
// down-scale 80%
b := img.Bounds()
if b.Dx() < 100 || b.Dy() < 100 {
return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
@@ -86,7 +86,7 @@ func compressGIF(src []byte, maxLength int) (string, string, error) {
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
}
// downscale every frame by 80%
// down-scale every frame by 80%
w, h := g.Config.Width, g.Config.Height
if w < 100 || h < 100 {
return "", "", fmt.Errorf("cannot compress animated GIF below 5 MiB without excessive quality loss")
@@ -94,7 +94,7 @@ func compressGIF(src []byte, maxLength int) (string, string, error) {
nw, nh := int(float64(w)*0.8), int(float64(h)*0.8)
for i, frm := range g.Image {
// convert paletted frame RGBA for scaling
// convert paletted frame -> RGBA for scaling
rgba := image.NewRGBA(frm.Bounds())
draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src)
@@ -109,6 +109,6 @@ func compressGIF(src []byte, maxLength int) (string, string, error) {
g.Image[i] = paletted
}
g.Config.Width, g.Config.Height = nw, nh
// loop back and test size again
// loop back and test size again ...
}
}

llm.go

@@ -1,286 +1,30 @@
package go_llm
package llm
import (
"context"
"fmt"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/param"
)
type Role string
const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
type Image struct {
Base64 string
ContentType string
Url string
}
func (i Image) toRaw() map[string]any {
res := map[string]any{
"base64": i.Base64,
"contenttype": i.ContentType,
"url": i.Url,
}
return res
}
func (i *Image) fromRaw(raw map[string]any) Image {
var res Image
res.Base64 = raw["base64"].(string)
res.ContentType = raw["contenttype"].(string)
res.Url = raw["url"].(string)
return res
}
type Message struct {
Role Role
Name string
Text string
Images []Image
}
func (m Message) toRaw() map[string]any {
res := map[string]any{
"role": m.Role,
"name": m.Name,
"text": m.Text,
}
images := make([]map[string]any, 0, len(m.Images))
for _, img := range m.Images {
images = append(images, img.toRaw())
}
res["images"] = images
return res
}
func (m *Message) fromRaw(raw map[string]any) Message {
var res Message
res.Role = Role(raw["role"].(string))
res.Name = raw["name"].(string)
res.Text = raw["text"].(string)
images := raw["images"].([]map[string]any)
for _, img := range images {
var i Image
res.Images = append(res.Images, i.fromRaw(img))
}
return res
}
func (m Message) toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion {
var res openai.ChatCompletionMessageParamUnion
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
var textContent param.Opt[string]
for _, img := range m.Images {
if img.Base64 != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: "data:" + img.ContentType + ";base64," + img.Base64,
},
},
},
)
} else if img.Url != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: img.Url,
},
},
},
)
}
}
if m.Text != "" {
if len(arrayOfContentParts) > 0 {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfText: &openai.ChatCompletionContentPartTextParam{
Text: "\n" + m.Text,
},
},
)
} else {
textContent = openai.String(m.Text)
}
}
a := strings.Split(model, "-")
useSystemInsteadOfDeveloper := true
if len(a) > 1 && a[0][0] == 'o' {
useSystemInsteadOfDeveloper = false
}
switch m.Role {
case RoleSystem:
if useSystemInsteadOfDeveloper {
res = openai.ChatCompletionMessageParamUnion{
OfSystem: &openai.ChatCompletionSystemMessageParam{
Content: openai.ChatCompletionSystemMessageParamContentUnion{
OfString: textContent,
},
},
}
} else {
res = openai.ChatCompletionMessageParamUnion{
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
OfString: textContent,
},
},
}
}
case RoleUser:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfUser: &openai.ChatCompletionUserMessageParam{
Name: name,
Content: openai.ChatCompletionUserMessageParamContentUnion{
OfString: textContent,
OfArrayOfContentParts: arrayOfContentParts,
},
},
}
case RoleAssistant:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
Name: name,
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
OfString: textContent,
},
},
}
}
return []openai.ChatCompletionMessageParamUnion{res}
}
type ToolCall struct {
ID string
FunctionCall FunctionCall
}
func (t ToolCall) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
}
res["function"] = t.FunctionCall.toRaw()
return res
}
func (t ToolCall) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
return []openai.ChatCompletionMessageParamUnion{{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
{
ID: t.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: t.FunctionCall.Name,
Arguments: t.FunctionCall.Arguments,
},
},
},
},
}}
}
type ToolCallResponse struct {
ID string
Result any
Error error
}
func (t ToolCallResponse) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
"result": t.Result,
}
if t.Error != nil {
res["error"] = t.Error.Error()
}
return res
}
func (t ToolCallResponse) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
var refusal string
if t.Error != nil {
refusal = t.Error.Error()
}
if refusal != "" {
if t.Result != "" {
t.Result = fmt.Sprint(t.Result) + " (error in execution: " + refusal + ")"
} else {
t.Result = "error in execution: " + refusal
}
}
return []openai.ChatCompletionMessageParamUnion{{
OfTool: &openai.ChatCompletionToolMessageParam{
ToolCallID: t.ID,
Content: openai.ChatCompletionToolMessageParamContentUnion{
OfString: openai.String(fmt.Sprint(t.Result)),
},
},
}}
}
// ChatCompletion is the interface for chat completion.
type ChatCompletion interface {
ChatComplete(ctx context.Context, req Request) (Response, error)
}
// LLM is the interface for language model providers.
type LLM interface {
ModelVersion(modelVersion string) (ChatCompletion, error)
}
// OpenAI creates a new OpenAI LLM provider with the given API key.
func OpenAI(key string) LLM {
return openaiImpl{key: key}
}
// Anthropic creates a new Anthropic LLM provider with the given API key.
func Anthropic(key string) LLM {
return anthropic{key: key}
return anthropicImpl{key: key}
}
// Google creates a new Google LLM provider with the given API key.
func Google(key string) LLM {
return google{key: key}
return googleImpl{key: key}
}
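The constructors and interfaces above compose in the obvious way: an `LLM` provider is bound to a model via `ModelVersion`, which yields a `ChatCompletion`. A hedged usage sketch (the model name and env var are illustrative, not from the diff):

```go
package main

import (
	"context"
	"fmt"
	"os"

	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)

func main() {
	provider := llm.OpenAI(os.Getenv("OPENAI_API_KEY"))

	// ModelVersion binds the provider to a concrete model.
	chat, err := provider.ModelVersion("gpt-4o") // illustrative model name
	if err != nil {
		panic(err)
	}

	resp, err := chat.ChatComplete(context.Background(), llm.Request{
		Messages: []llm.Message{
			{Role: llm.RoleSystem, Text: "You are terse."},
			{Role: llm.RoleUser, Text: "Say hi."},
		},
	})
	if err != nil {
		panic(err)
	}
	for _, choice := range resp.Choices {
		fmt.Println(choice.Content)
	}
}
```

The same shape works for `llm.Anthropic` and `llm.Google`, since all three return the `LLM` interface.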

mcp.go Normal file

@@ -0,0 +1,238 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"sync"
"github.com/modelcontextprotocol/go-sdk/mcp"
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)
// MCPServer represents a connection to an MCP server.
// It manages the lifecycle of the connection and provides access to the server's tools.
type MCPServer struct {
// Name is a friendly name for this server (used for logging/identification)
Name string
// Command is the command to run the MCP server (for stdio transport)
Command string
// Args are arguments to pass to the command
Args []string
// Env are environment variables to set for the command (in addition to current environment)
Env []string
// URL is the URL for SSE or HTTP transport (alternative to Command)
URL string
// Transport specifies the transport type: "stdio" (default), "sse", or "http"
Transport string
client *mcp.Client
session *mcp.ClientSession
tools map[string]*mcp.Tool // tool name -> tool definition
mu sync.RWMutex
}
// Connect establishes a connection to the MCP server.
func (m *MCPServer) Connect(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.session != nil {
return nil // Already connected
}
m.client = mcp.NewClient(&mcp.Implementation{
Name: "go-llm",
Version: "1.0.0",
}, nil)
var transport mcp.Transport
switch m.Transport {
case "sse":
transport = &mcp.SSEClientTransport{
Endpoint: m.URL,
}
case "http":
transport = &mcp.StreamableClientTransport{
Endpoint: m.URL,
}
default: // "stdio" or empty
cmd := exec.Command(m.Command, m.Args...)
cmd.Env = append(os.Environ(), m.Env...)
transport = &mcp.CommandTransport{
Command: cmd,
}
}
session, err := m.client.Connect(ctx, transport, nil)
if err != nil {
return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err)
}
m.session = session
// Load tools
m.tools = make(map[string]*mcp.Tool)
for tool, err := range session.Tools(ctx, nil) {
if err != nil {
m.session.Close()
m.session = nil
return fmt.Errorf("failed to list tools from %s: %w", m.Name, err)
}
m.tools[tool.Name] = tool
}
return nil
}
// Close closes the connection to the MCP server.
func (m *MCPServer) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.session == nil {
return nil
}
err := m.session.Close()
m.session = nil
m.tools = nil
return err
}
// IsConnected returns true if the server is connected.
func (m *MCPServer) IsConnected() bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.session != nil
}
// Tools returns the list of tool names available from this server.
func (m *MCPServer) Tools() []string {
m.mu.RLock()
defer m.mu.RUnlock()
var names []string
for name := range m.tools {
names = append(names, name)
}
return names
}
// HasTool returns true if this server provides the named tool.
func (m *MCPServer) HasTool(name string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.tools[name]
return ok
}
// CallTool calls a tool on the MCP server.
func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) {
m.mu.RLock()
session := m.session
m.mu.RUnlock()
if session == nil {
return nil, fmt.Errorf("not connected to MCP server %s", m.Name)
}
result, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: name,
Arguments: arguments,
})
if err != nil {
return nil, err
}
// Process the result content
if len(result.Content) == 0 {
return nil, nil
}
// If there's a single text content, return it as a string
if len(result.Content) == 1 {
if textContent, ok := result.Content[0].(*mcp.TextContent); ok {
return textContent.Text, nil
}
}
// For multiple contents or non-text, serialize to string
return contentToString(result.Content), nil
}
// toFunction converts an MCP tool to a go-llm Function (for schema purposes only).
func (m *MCPServer) toFunction(tool *mcp.Tool) Function {
var inputSchema map[string]any
if tool.InputSchema != nil {
data, err := json.Marshal(tool.InputSchema)
if err == nil {
_ = json.Unmarshal(data, &inputSchema)
}
}
if inputSchema == nil {
inputSchema = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
return Function{
Name: tool.Name,
Description: tool.Description,
Parameters: schema.NewRaw(inputSchema),
}
}
// contentToString converts MCP content to a string representation.
func contentToString(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch tc := c.(type) {
case *mcp.TextContent:
parts = append(parts, tc.Text)
default:
if data, err := json.Marshal(c); err == nil {
parts = append(parts, string(data))
}
}
}
if len(parts) == 1 {
return parts[0]
}
data, _ := json.Marshal(parts)
return string(data)
}
// WithMCPServer adds an MCP server to the toolbox.
// The server must already be connected. Tools from the server will be available
// for use, and tool calls will be routed to the appropriate server.
func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox {
if t.mcpServers == nil {
t.mcpServers = make(map[string]*MCPServer)
}
server.mu.RLock()
defer server.mu.RUnlock()
for name, tool := range server.tools {
// Add the function definition (for schema)
fn := server.toFunction(tool)
t.functions[name] = fn
// Track which server owns this tool
t.mcpServers[name] = server
}
return t
}
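Putting the MCP lifecycle together: construct an `MCPServer`, `Connect` it, then either call tools directly or hand the connected server to `ToolBox.WithMCPServer` for routing. A sketch under stated assumptions (the server binary, tool name, and arguments are hypothetical):

```go
package main

import (
	"context"
	"fmt"

	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)

func main() {
	ctx := context.Background()

	// stdio transport is the default when Transport is empty;
	// "my-mcp-server" is a hypothetical binary.
	server := &llm.MCPServer{
		Name:    "files",
		Command: "my-mcp-server",
	}
	if err := server.Connect(ctx); err != nil {
		panic(err)
	}
	defer server.Close()

	// Tools discovered during Connect.
	fmt.Println(server.Tools())

	// Direct invocation; tool name and arguments are illustrative.
	result, err := server.CallTool(ctx, "read_file", map[string]any{
		"path": "/tmp/example.txt",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(result)
}
```

For SSE or streamable-HTTP servers, set `URL` and `Transport: "sse"` or `"http"` instead of `Command`.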

message.go Normal file

@@ -0,0 +1,115 @@
package llm
// Role represents the role of a message in a conversation.
type Role string
const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
// Image represents an image that can be included in a message.
type Image struct {
Base64 string
ContentType string
Url string
}
func (i Image) toRaw() map[string]any {
res := map[string]any{
"base64": i.Base64,
"contenttype": i.ContentType,
"url": i.Url,
}
return res
}
func (i *Image) fromRaw(raw map[string]any) Image {
var res Image
res.Base64 = raw["base64"].(string)
res.ContentType = raw["contenttype"].(string)
res.Url = raw["url"].(string)
return res
}
// Message represents a message in a conversation.
type Message struct {
Role Role
Name string
Text string
Images []Image
}
func (m Message) toRaw() map[string]any {
res := map[string]any{
"role": m.Role,
"name": m.Name,
"text": m.Text,
}
images := make([]map[string]any, 0, len(m.Images))
for _, img := range m.Images {
images = append(images, img.toRaw())
}
res["images"] = images
return res
}
func (m *Message) fromRaw(raw map[string]any) Message {
var res Message
res.Role = Role(raw["role"].(string))
res.Name = raw["name"].(string)
res.Text = raw["text"].(string)
images := raw["images"].([]map[string]any)
for _, img := range images {
var i Image
res.Images = append(res.Images, i.fromRaw(img))
}
return res
}
// ToolCall represents a tool call made by an assistant.
type ToolCall struct {
ID string
FunctionCall FunctionCall
}
func (t ToolCall) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
}
res["function"] = t.FunctionCall.toRaw()
return res
}
// ToolCallResponse represents the response to a tool call.
type ToolCallResponse struct {
ID string
Result any
Error error
}
func (t ToolCallResponse) toRaw() map[string]any {
res := map[string]any{
"id": t.ID,
"result": t.Result,
}
if t.Error != nil {
res["error"] = t.Error.Error()
}
return res
}

openai.go

@@ -1,4 +1,4 @@
package go_llm
package llm
import (
"context"
@@ -7,6 +7,7 @@ import (
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/shared"
)
@@ -24,14 +25,14 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple
}
for _, i := range request.Conversation {
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
res.Messages = append(res.Messages, inputToChatCompletionMessages(i, o.model)...)
}
for _, msg := range request.Messages {
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
res.Messages = append(res.Messages, messageToChatCompletionMessages(msg, o.model)...)
}
for _, tool := range request.Toolbox.functions {
for _, tool := range request.Toolbox.Functions() {
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
Type: "function",
Function: shared.FunctionDefinitionParam{
@@ -111,10 +112,9 @@ func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response
req := o.newRequestToOpenAIRequest(request)
resp, err := cl.Chat.Completions.New(ctx, req)
//resp, err := cl.CreateChatCompletion(ctx, req)
if err != nil {
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
return Response{}, fmt.Errorf("unhandled openai error: %w", err)
}
return o.responseToLLMResponse(resp), nil
@@ -122,7 +122,201 @@ func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
return openaiImpl{
key: o.key,
model: modelVersion,
key: o.key,
model: modelVersion,
baseUrl: o.baseUrl,
}, nil
}
// inputToChatCompletionMessages converts an Input to OpenAI chat completion messages.
func inputToChatCompletionMessages(input Input, model string) []openai.ChatCompletionMessageParamUnion {
switch v := input.(type) {
case Message:
return messageToChatCompletionMessages(v, model)
case ToolCall:
return toolCallToChatCompletionMessages(v)
case ToolCallResponse:
return toolCallResponseToChatCompletionMessages(v)
case ResponseChoice:
return responseChoiceToChatCompletionMessages(v)
default:
return nil
}
}
func messageToChatCompletionMessages(m Message, model string) []openai.ChatCompletionMessageParamUnion {
var res openai.ChatCompletionMessageParamUnion
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
var textContent param.Opt[string]
for _, img := range m.Images {
if img.Base64 != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: "data:" + img.ContentType + ";base64," + img.Base64,
},
},
},
)
} else if img.Url != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: img.Url,
},
},
},
)
}
}
if m.Text != "" {
if len(arrayOfContentParts) > 0 {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfText: &openai.ChatCompletionContentPartTextParam{
Text: "\n" + m.Text,
},
},
)
} else {
textContent = openai.String(m.Text)
}
}
a := strings.Split(model, "-")
useSystemInsteadOfDeveloper := true
if len(a) > 1 && a[0][0] == 'o' {
useSystemInsteadOfDeveloper = false
}
switch m.Role {
case RoleSystem:
if useSystemInsteadOfDeveloper {
res = openai.ChatCompletionMessageParamUnion{
OfSystem: &openai.ChatCompletionSystemMessageParam{
Content: openai.ChatCompletionSystemMessageParamContentUnion{
OfString: textContent,
},
},
}
} else {
res = openai.ChatCompletionMessageParamUnion{
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
OfString: textContent,
},
},
}
}
case RoleUser:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfUser: &openai.ChatCompletionUserMessageParam{
Name: name,
Content: openai.ChatCompletionUserMessageParamContentUnion{
OfString: textContent,
OfArrayOfContentParts: arrayOfContentParts,
},
},
}
case RoleAssistant:
var name param.Opt[string]
if m.Name != "" {
name = openai.String(m.Name)
}
res = openai.ChatCompletionMessageParamUnion{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
Name: name,
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
OfString: textContent,
},
},
}
}
return []openai.ChatCompletionMessageParamUnion{res}
}
func toolCallToChatCompletionMessages(t ToolCall) []openai.ChatCompletionMessageParamUnion {
return []openai.ChatCompletionMessageParamUnion{{
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
{
ID: t.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: t.FunctionCall.Name,
Arguments: t.FunctionCall.Arguments,
},
},
},
},
}}
}
func toolCallResponseToChatCompletionMessages(t ToolCallResponse) []openai.ChatCompletionMessageParamUnion {
var refusal string
if t.Error != nil {
refusal = t.Error.Error()
}
result := t.Result
if refusal != "" {
if result != "" {
result = fmt.Sprint(result) + " (error in execution: " + refusal + ")"
} else {
result = "error in execution: " + refusal
}
}
return []openai.ChatCompletionMessageParamUnion{{
OfTool: &openai.ChatCompletionToolMessageParam{
ToolCallID: t.ID,
Content: openai.ChatCompletionToolMessageParamContentUnion{
OfString: openai.String(fmt.Sprint(result)),
},
},
}}
}
func responseChoiceToChatCompletionMessages(r ResponseChoice) []openai.ChatCompletionMessageParamUnion {
var as openai.ChatCompletionAssistantMessageParam
if r.Name != "" {
as.Name = openai.String(r.Name)
}
if r.Refusal != "" {
as.Refusal = openai.String(r.Refusal)
}
if r.Content != "" {
as.Content.OfString = openai.String(r.Content)
}
for _, call := range r.Calls {
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: call.FunctionCall.Name,
Arguments: call.FunctionCall.Arguments,
},
})
}
return []openai.ChatCompletionMessageParamUnion{
{
OfAssistant: &as,
},
}
}

openai_transcriber.go Normal file

@@ -0,0 +1,219 @@
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
)
type openaiTranscriber struct {
key string
model string
baseUrl string
}
var _ Transcriber = openaiTranscriber{}
// OpenAITranscriber creates a transcriber backed by OpenAI's audio models.
// If model is empty, whisper-1 is used by default.
func OpenAITranscriber(key string, model string) Transcriber {
if strings.TrimSpace(model) == "" {
model = "whisper-1"
}
return openaiTranscriber{
key: key,
model: model,
}
}
func (o openaiTranscriber) Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) {
if len(wav) == 0 {
return Transcription{}, fmt.Errorf("wav data is empty")
}
format := opts.ResponseFormat
if format == "" {
if strings.HasPrefix(o.model, "gpt-4o") {
format = TranscriptionResponseFormatJSON
} else {
format = TranscriptionResponseFormatVerboseJSON
}
}
if format != TranscriptionResponseFormatJSON && format != TranscriptionResponseFormatVerboseJSON {
return Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output")
}
if len(opts.TimestampGranularities) > 0 && format != TranscriptionResponseFormatVerboseJSON {
return Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json")
}
params := openai.AudioTranscriptionNewParams{
File: openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"),
Model: openai.AudioModel(o.model),
}
if opts.Language != "" {
params.Language = openai.String(opts.Language)
}
if opts.Prompt != "" {
params.Prompt = openai.String(opts.Prompt)
}
if opts.Temperature != nil {
params.Temperature = openai.Float(*opts.Temperature)
}
params.ResponseFormat = openai.AudioResponseFormat(format)
if opts.IncludeLogprobs {
params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs}
}
if len(opts.TimestampGranularities) > 0 {
for _, granularity := range opts.TimestampGranularities {
params.TimestampGranularities = append(params.TimestampGranularities, string(granularity))
}
}
clientOptions := []option.RequestOption{
option.WithAPIKey(o.key),
}
if o.baseUrl != "" {
clientOptions = append(clientOptions, option.WithBaseURL(o.baseUrl))
}
client := openai.NewClient(clientOptions...)
resp, err := client.Audio.Transcriptions.New(ctx, params)
if err != nil {
return Transcription{}, fmt.Errorf("openai transcription failed: %w", err)
}
return openaiTranscriptionToResult(o.model, resp), nil
}
type openaiVerboseTranscription struct {
Text string `json:"text"`
Language string `json:"language"`
Duration float64 `json:"duration"`
Segments []openaiVerboseSegment `json:"segments"`
Words []openaiVerboseWord `json:"words"`
}
type openaiVerboseSegment struct {
ID int `json:"id"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
AvgLogprob *float64 `json:"avg_logprob"`
CompressionRatio *float64 `json:"compression_ratio"`
NoSpeechProb *float64 `json:"no_speech_prob"`
Words []openaiVerboseWord `json:"words"`
}
type openaiVerboseWord struct {
Word string `json:"word"`
Start float64 `json:"start"`
End float64 `json:"end"`
}
func openaiTranscriptionToResult(model string, resp *openai.Transcription) Transcription {
result := Transcription{
Provider: "openai",
Model: model,
}
if resp == nil {
return result
}
result.Text = resp.Text
result.RawJSON = resp.RawJSON()
for _, logprob := range resp.Logprobs {
result.Logprobs = append(result.Logprobs, TranscriptionTokenLogprob{
Token: logprob.Token,
Bytes: logprob.Bytes,
Logprob: logprob.Logprob,
})
}
if usage := openaiUsageToTranscriptionUsage(resp.Usage); usage.Type != "" {
result.Usage = usage
}
if result.RawJSON == "" {
return result
}
var verbose openaiVerboseTranscription
if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil {
return result
}
if verbose.Text != "" {
result.Text = verbose.Text
}
result.Language = verbose.Language
result.DurationSeconds = verbose.Duration
for _, seg := range verbose.Segments {
segment := TranscriptionSegment{
ID: seg.ID,
Start: seg.Start,
End: seg.End,
Text: seg.Text,
Tokens: append([]int(nil), seg.Tokens...),
AvgLogprob: seg.AvgLogprob,
CompressionRatio: seg.CompressionRatio,
NoSpeechProb: seg.NoSpeechProb,
}
for _, word := range seg.Words {
segment.Words = append(segment.Words, TranscriptionWord{
Word: word.Word,
Start: word.Start,
End: word.End,
})
}
result.Segments = append(result.Segments, segment)
}
for _, word := range verbose.Words {
result.Words = append(result.Words, TranscriptionWord{
Word: word.Word,
Start: word.Start,
End: word.End,
})
}
return result
}
func openaiUsageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) TranscriptionUsage {
switch usage.Type {
case "tokens":
tokens := usage.AsTokens()
return TranscriptionUsage{
Type: usage.Type,
InputTokens: tokens.InputTokens,
OutputTokens: tokens.OutputTokens,
TotalTokens: tokens.TotalTokens,
AudioTokens: tokens.InputTokenDetails.AudioTokens,
TextTokens: tokens.InputTokenDetails.TextTokens,
}
case "duration":
duration := usage.AsDuration()
return TranscriptionUsage{
Type: usage.Type,
Seconds: duration.Seconds,
}
default:
return TranscriptionUsage{}
}
}


@@ -1,4 +1,4 @@
package go_llm
package llm
import (
"strings"


@@ -0,0 +1,11 @@
// Package anthropic provides the Anthropic LLM provider.
package anthropic
import (
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// New creates a new Anthropic LLM provider with the given API key.
func New(key string) llm.LLM {
return llm.Anthropic(key)
}

provider/google/google.go Normal file

@@ -0,0 +1,11 @@
// Package google provides the Google LLM provider.
package google
import (
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// New creates a new Google LLM provider with the given API key.
func New(key string) llm.LLM {
return llm.Google(key)
}

provider/openai/openai.go Normal file

@@ -0,0 +1,11 @@
// Package openai provides the OpenAI LLM provider.
package openai
import (
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)
// New creates a new OpenAI LLM provider with the given API key.
func New(key string) llm.LLM {
return llm.OpenAI(key)
}


@@ -1,17 +1,20 @@
package go_llm
import (
"github.com/openai/openai-go"
)
type rawAble interface {
toRaw() map[string]any
fromRaw(raw map[string]any) Input
}
package llm
// Input is the interface for conversation inputs.
// Types that implement this interface can be part of a conversation:
// Message, ToolCall, ToolCallResponse, and ResponseChoice.
type Input interface {
toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion
// isInput is a marker method to ensure only valid types implement this interface.
isInput()
}
// Implement Input interface for all valid input types.
func (Message) isInput() {}
func (ToolCall) isInput() {}
func (ToolCallResponse) isInput() {}
func (ResponseChoice) isInput() {}
// Request represents a request to a language model.
type Request struct {
Conversation []Input
Messages []Message


@@ -1,9 +1,6 @@
package go_llm
import (
"github.com/openai/openai-go"
)
package llm
// ResponseChoice represents a single choice in a response.
type ResponseChoice struct {
Index int
Role Role
@@ -32,36 +29,6 @@ func (r ResponseChoice) toRaw() map[string]any {
return res
}
func (r ResponseChoice) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
var as openai.ChatCompletionAssistantMessageParam
if r.Name != "" {
as.Name = openai.String(r.Name)
}
if r.Refusal != "" {
as.Refusal = openai.String(r.Refusal)
}
if r.Content != "" {
as.Content.OfString = openai.String(r.Content)
}
for _, call := range r.Calls {
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: call.FunctionCall.Name,
Arguments: call.FunctionCall.Arguments,
},
})
}
return []openai.ChatCompletionMessageParamUnion{
{
OfAssistant: &as,
},
}
}
func (r ResponseChoice) toInput() []Input {
var res []Input
@@ -79,6 +46,7 @@ func (r ResponseChoice) toInput() []Input {
return res
}
// Response represents a response from a language model.
type Response struct {
Choices []ResponseChoice
}


@@ -4,8 +4,8 @@ import (
"errors"
"reflect"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"google.golang.org/genai"
)
type array struct {


@@ -5,8 +5,8 @@ import (
"reflect"
"strconv"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"google.golang.org/genai"
)
// just enforcing that basic implements Type


@@ -5,8 +5,8 @@ import (
"reflect"
"slices"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"google.golang.org/genai"
)
type enum struct {


@@ -4,8 +4,8 @@ import (
"errors"
"reflect"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"google.golang.org/genai"
)
const (

schema/raw.go Normal file

@@ -0,0 +1,134 @@
package schema
import (
"encoding/json"
"fmt"
"reflect"
"github.com/openai/openai-go"
"google.golang.org/genai"
)
// Raw represents a raw JSON schema that is passed through directly.
// This is used for MCP tools where we receive the schema from the server.
type Raw struct {
schema map[string]any
}
// NewRaw creates a new Raw schema from a map.
func NewRaw(schema map[string]any) Raw {
if schema == nil {
schema = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
return Raw{schema: schema}
}
// NewRawFromJSON creates a new Raw schema from JSON bytes.
func NewRawFromJSON(data []byte) (Raw, error) {
var schema map[string]any
if err := json.Unmarshal(data, &schema); err != nil {
return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err)
}
return NewRaw(schema), nil
}
func (r Raw) OpenAIParameters() openai.FunctionParameters {
return openai.FunctionParameters(r.schema)
}
func (r Raw) GoogleParameters() *genai.Schema {
return mapToGenaiSchema(r.schema)
}
func (r Raw) AnthropicInputSchema() map[string]any {
return r.schema
}
func (r Raw) Required() bool {
return false
}
func (r Raw) Description() string {
if desc, ok := r.schema["description"].(string); ok {
return desc
}
return ""
}
func (r Raw) FromAny(val any) (reflect.Value, error) {
return reflect.ValueOf(val), nil
}
func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) {
// No-op for raw schemas
}
// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema
func mapToGenaiSchema(m map[string]any) *genai.Schema {
if m == nil {
return nil
}
schema := &genai.Schema{}
// Type
if t, ok := m["type"].(string); ok {
switch t {
case "string":
schema.Type = genai.TypeString
case "number":
schema.Type = genai.TypeNumber
case "integer":
schema.Type = genai.TypeInteger
case "boolean":
schema.Type = genai.TypeBoolean
case "array":
schema.Type = genai.TypeArray
case "object":
schema.Type = genai.TypeObject
}
}
// Description
if desc, ok := m["description"].(string); ok {
schema.Description = desc
}
// Enum
if enum, ok := m["enum"].([]any); ok {
for _, e := range enum {
if s, ok := e.(string); ok {
schema.Enum = append(schema.Enum, s)
}
}
}
// Properties (for objects)
if props, ok := m["properties"].(map[string]any); ok {
schema.Properties = make(map[string]*genai.Schema)
for k, v := range props {
if vm, ok := v.(map[string]any); ok {
schema.Properties[k] = mapToGenaiSchema(vm)
}
}
}
// Required
if req, ok := m["required"].([]any); ok {
for _, r := range req {
if s, ok := r.(string); ok {
schema.Required = append(schema.Required, s)
}
}
}
// Items (for arrays)
if items, ok := m["items"].(map[string]any); ok {
schema.Items = mapToGenaiSchema(items)
}
return schema
}


@@ -3,8 +3,8 @@ package schema
import (
"reflect"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"google.golang.org/genai"
)
type Type interface {


@@ -1,7 +1,8 @@
package go_llm
package llm
import (
"context"
"encoding/json"
"errors"
"fmt"
)
@@ -11,6 +12,7 @@ import (
// the correct parameters.
type ToolBox struct {
functions map[string]Function
mcpServers map[string]*MCPServer // tool name -> MCP server that provides it
dontRequireTool bool
}
@@ -91,6 +93,18 @@ var (
)
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
// Check if this is an MCP tool
if server, ok := t.mcpServers[functionName]; ok {
var args map[string]any
if params != "" {
if err := json.Unmarshal([]byte(params), &args); err != nil {
return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err)
}
}
return server.CallTool(ctx, functionName, args)
}
// Regular function
f, ok := t.functions[functionName]
if !ok {

transcriber.go Normal file

@@ -0,0 +1,145 @@
package llm
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
)
// Transcriber abstracts a speech-to-text model implementation.
type Transcriber interface {
Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error)
}
// TranscriptionResponseFormat controls the output format requested from a transcriber.
type TranscriptionResponseFormat string
const (
TranscriptionResponseFormatJSON TranscriptionResponseFormat = "json"
TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json"
TranscriptionResponseFormatText TranscriptionResponseFormat = "text"
TranscriptionResponseFormatSRT TranscriptionResponseFormat = "srt"
TranscriptionResponseFormatVTT TranscriptionResponseFormat = "vtt"
)
// TranscriptionTimestampGranularity defines the requested timestamp detail.
type TranscriptionTimestampGranularity string
const (
TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word"
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
)
// TranscriptionOptions configures transcription behavior.
type TranscriptionOptions struct {
Language string
Prompt string
Temperature *float64
ResponseFormat TranscriptionResponseFormat
TimestampGranularities []TranscriptionTimestampGranularity
IncludeLogprobs bool
}
// Transcription captures a normalized transcription result.
type Transcription struct {
Provider string
Model string
Text string
Language string
DurationSeconds float64
Segments []TranscriptionSegment
Words []TranscriptionWord
Logprobs []TranscriptionTokenLogprob
Usage TranscriptionUsage
RawJSON string
}
// TranscriptionSegment provides a coarse time-sliced transcription segment.
type TranscriptionSegment struct {
ID int
Start float64
End float64
Text string
Tokens []int
AvgLogprob *float64
CompressionRatio *float64
NoSpeechProb *float64
Words []TranscriptionWord
}
// TranscriptionWord provides a word-level timestamp.
type TranscriptionWord struct {
Word string
Start float64
End float64
Confidence *float64
}
// TranscriptionTokenLogprob captures token-level log probability details.
type TranscriptionTokenLogprob struct {
Token string
Bytes []float64
Logprob float64
}
// TranscriptionUsage captures token or duration usage details.
type TranscriptionUsage struct {
Type string
InputTokens int64
OutputTokens int64
TotalTokens int64
AudioTokens int64
TextTokens int64
Seconds float64
}
// TranscribeFile converts an audio file to WAV and transcribes it.
func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) {
if transcriber == nil {
return Transcription{}, fmt.Errorf("transcriber is nil")
}
wav, err := audioFileToWav(ctx, filename)
if err != nil {
return Transcription{}, err
}
return transcriber.Transcribe(ctx, wav, opts)
}
func audioFileToWav(ctx context.Context, filename string) ([]byte, error) {
if filename == "" {
return nil, fmt.Errorf("filename is empty")
}
if strings.EqualFold(filepath.Ext(filename), ".wav") {
data, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("read wav file: %w", err)
}
return data, nil
}
tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav")
if err != nil {
return nil, fmt.Errorf("create temp wav file: %w", err)
}
tempPath := tempFile.Name()
_ = tempFile.Close()
defer os.Remove(tempPath)
cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath)
if output, err := cmd.CombinedOutput(); err != nil {
return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output)))
}
data, err := os.ReadFile(tempPath)
if err != nil {
return nil, fmt.Errorf("read converted wav file: %w", err)
}
return data, nil
}

v2/CLAUDE.md Normal file

@@ -0,0 +1,32 @@
# CLAUDE.md for go-llm v2
## Build and Test Commands
- Build project: `cd v2 && go build ./...`
- Run all tests: `cd v2 && go test ./...`
- Run specific test: `cd v2 && go test -v -run <TestName> ./...`
- Tidy dependencies: `cd v2 && go mod tidy`
- Vet: `cd v2 && go vet ./...`
## Code Style Guidelines
- **Indentation**: Standard Go tabs
- **Naming**: `camelCase` for unexported, `PascalCase` for exported
- **Error Handling**: Always check and handle errors immediately. Wrap with `fmt.Errorf("...: %w", err)`
- **Imports**: Standard library first, then third-party, then internal packages
## Package Structure
- Root package `llm` — public API (Client, Model, Chat, ToolBox, Message types)
- `provider/` — Provider interface that backends implement
- `openai/`, `anthropic/`, `google/` — Provider implementations
- `tools/` — Ready-to-use sample tools (WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP)
- `sandbox/` — Isolated Linux container environments via Proxmox LXC + SSH
- `internal/schema/` — JSON Schema generation from Go structs
- `internal/imageutil/` — Image compression utilities
## Key Design Decisions
1. Unified `Message` type instead of marker interfaces
2. `map[string]any` JSON Schema (no provider coupling)
3. Tool functions return `(string, error)`, use standard `context.Context`
4. `Chat.Send()` auto-loops tool calls; `Chat.SendRaw()` for manual control
5. MCP one-call connect: `MCPStdioServer(ctx, cmd, args...)`
6. Streaming via pull-based `StreamReader.Next()`
7. Middleware for logging, retry, timeout, usage tracking

v2/agent/agent.go Normal file

@@ -0,0 +1,113 @@
// Package agent provides a simple agent abstraction built on top of go-llm.
//
// An Agent wraps a model, system prompt, and tools into a reusable unit.
// Agents can be turned into tools via AsTool, enabling parent agents to
// delegate work to specialized sub-agents through the normal tool-call loop.
//
// Example — orchestrator with sub-agents:
//
// researcher := agent.New(model, "You research topics via web search.",
// agent.WithTools(llm.NewToolBox(tools.WebSearch(apiKey))),
// )
// coder := agent.New(model, "You write and run code.",
// agent.WithTools(llm.NewToolBox(tools.Exec())),
// )
// orchestrator := agent.New(model, "You coordinate research and coding tasks.",
// agent.WithTools(llm.NewToolBox(
// researcher.AsTool("research", "Research a topic"),
// coder.AsTool("code", "Write and run code"),
// )),
// )
// result, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
package agent
import (
"context"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// Agent is a configured LLM agent with a system prompt and tools.
// Each call to Run creates a fresh conversation (no state is carried between runs).
type Agent struct {
model *llm.Model
system string
tools *llm.ToolBox
reqOpts []llm.RequestOption
}
// Option configures an Agent.
type Option func(*Agent)
// WithTools sets the tools available to the agent.
func WithTools(tb *llm.ToolBox) Option {
return func(a *Agent) { a.tools = tb }
}
// WithRequestOptions sets default request options (temperature, max tokens, etc.)
// applied to every completion call the agent makes.
func WithRequestOptions(opts ...llm.RequestOption) Option {
return func(a *Agent) { a.reqOpts = opts }
}
// New creates an agent with the given model and system prompt.
func New(model *llm.Model, system string, opts ...Option) *Agent {
a := &Agent{
model: model,
system: system,
}
for _, opt := range opts {
opt(a)
}
return a
}
// Run executes the agent with a user prompt. Each call is a fresh conversation.
// The agent loops tool calls automatically until it produces a text response.
func (a *Agent) Run(ctx context.Context, prompt string) (string, error) {
return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)})
}
// RunMessages executes the agent with full message control.
// Each call is a fresh conversation. The agent loops tool calls automatically.
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, error) {
chat := llm.NewChat(a.model, a.reqOpts...)
if a.system != "" {
chat.SetSystem(a.system)
}
if a.tools != nil {
chat.SetTools(a.tools)
}
// Send each message; the last one triggers the completion loop.
// All but the last are added as context.
for i, msg := range messages {
if i < len(messages)-1 {
chat.AddToolResults(msg) // AddToolResults just appends to history
continue
}
return chat.SendMessage(ctx, msg)
}
// Empty messages — send an empty user message
return chat.Send(ctx, "")
}
// delegateParams is the parameter struct for the tool created by AsTool.
type delegateParams struct {
Input string `json:"input" description:"The task or question to delegate to this agent"`
}
// AsTool creates an llm.Tool that delegates to this agent.
// When a parent agent calls this tool, it runs the agent with the provided input
// as the prompt and returns the agent's text response as the tool result.
//
// This enables sub-agent patterns where a parent agent can spawn specialized
// child agents through the normal tool-call mechanism.
func (a *Agent) AsTool(name, description string) llm.Tool {
return llm.Define[delegateParams](name, description,
func(ctx context.Context, p delegateParams) (string, error) {
return a.Run(ctx, p.Input)
},
)
}

v2/agent/agent_test.go Normal file

@@ -0,0 +1,244 @@
package agent
import (
"context"
"errors"
"sync"
"testing"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// mockProvider is a test helper that implements provider.Provider.
type mockProvider struct {
mu sync.Mutex
completeFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
requests []provider.Request
}
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
m.mu.Lock()
m.requests = append(m.requests, req)
m.mu.Unlock()
return m.completeFunc(ctx, req)
}
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
close(events)
return nil
}
func (m *mockProvider) lastRequest() provider.Request {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.requests) == 0 {
return provider.Request{}
}
return m.requests[len(m.requests)-1]
}
func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model {
mp := &mockProvider{completeFunc: fn}
return llm.NewClient(mp).Model("mock-model")
}
func newSimpleMockModel(text string) *llm.Model {
return newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{Text: text}, nil
})
}
func TestAgent_Run(t *testing.T) {
model := newSimpleMockModel("Hello from agent!")
a := New(model, "You are a helpful assistant.")
result, err := a.Run(context.Background(), "Say hello")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "Hello from agent!" {
t.Errorf("expected 'Hello from agent!', got %q", result)
}
}
func TestAgent_Run_WithTools(t *testing.T) {
callCount := 0
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
callCount++
if callCount == 1 {
// First call: model requests a tool call
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "greet", Arguments: `{}`},
},
}, nil
}
// Second call: model returns text after seeing tool result
return provider.Response{Text: "Tool said: hello!"}, nil
})
tool := llm.DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello!", nil
})
a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool)))
result, err := a.Run(context.Background(), "Use the greet tool")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "Tool said: hello!" {
t.Errorf("expected 'Tool said: hello!', got %q", result)
}
if callCount != 2 {
t.Errorf("expected 2 calls (tool loop), got %d", callCount)
}
}
func TestAgent_AsTool(t *testing.T) {
// Create a child agent
childModel := newSimpleMockModel("child result: 42")
child := New(childModel, "You compute things.")
// Create the tool from the child agent
childTool := child.AsTool("compute", "Delegate computation to child agent")
// Verify tool metadata
if childTool.Name != "compute" {
t.Errorf("expected tool name 'compute', got %q", childTool.Name)
}
if childTool.Description != "Delegate computation to child agent" {
t.Errorf("expected correct description, got %q", childTool.Description)
}
// Execute the tool directly (simulating what the parent's Chat.Send loop does)
result, err := childTool.Execute(context.Background(), `{"input":"what is 6*7?"}`)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "child result: 42" {
t.Errorf("expected 'child result: 42', got %q", result)
}
}
func TestAgent_AsTool_ParentChild(t *testing.T) {
// Child agent that always returns a fixed result
childModel := newSimpleMockModel("researched: Go generics are great")
child := New(childModel, "You are a researcher.")
// Parent agent: first call returns tool call, second returns text
parentCallCount := 0
parentModel := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
parentCallCount++
if parentCallCount == 1 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "research", Arguments: `{"input":"Tell me about Go generics"}`},
},
}, nil
}
// After getting tool result, parent synthesizes final answer
return provider.Response{Text: "Based on research: Go generics are great"}, nil
})
parent := New(parentModel, "You coordinate tasks.",
WithTools(llm.NewToolBox(
child.AsTool("research", "Research a topic"),
)),
)
result, err := parent.Run(context.Background(), "Tell me about Go generics")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "Based on research: Go generics are great" {
t.Errorf("expected synthesized result, got %q", result)
}
if parentCallCount != 2 {
t.Errorf("expected 2 parent calls (tool loop), got %d", parentCallCount)
}
}
func TestAgent_RunMessages(t *testing.T) {
model := newSimpleMockModel("I see the system and user messages")
a := New(model, "You are helpful.")
messages := []llm.Message{
llm.UserMessage("First question"),
llm.AssistantMessage("First answer"),
llm.UserMessage("Follow up"),
}
result, err := a.RunMessages(context.Background(), messages)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "I see the system and user messages" {
t.Errorf("unexpected result: %q", result)
}
}
func TestAgent_ContextCancellation(t *testing.T) {
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, ctx.Err()
})
a := New(model, "You are helpful.")
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := a.Run(ctx, "This should fail")
if err == nil {
t.Fatal("expected error from cancelled context")
}
}
func TestAgent_WithRequestOptions(t *testing.T) {
var capturedReq provider.Request
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
capturedReq = req
return provider.Response{Text: "ok"}, nil
})
a := New(model, "You are helpful.",
WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)),
)
_, err := a.Run(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if capturedReq.Temperature == nil || *capturedReq.Temperature != 0.3 {
t.Errorf("expected temperature 0.3, got %v", capturedReq.Temperature)
}
if capturedReq.MaxTokens == nil || *capturedReq.MaxTokens != 100 {
t.Errorf("expected maxTokens 100, got %v", capturedReq.MaxTokens)
}
}
func TestAgent_Run_Error(t *testing.T) {
wantErr := errors.New("model failed")
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, wantErr
})
a := New(model, "You are helpful.")
_, err := a.Run(context.Background(), "test")
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestAgent_EmptySystem(t *testing.T) {
model := newSimpleMockModel("no system prompt")
a := New(model, "") // Empty system prompt
result, err := a.Run(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "no system prompt" {
t.Errorf("unexpected result: %q", result)
}
}

v2/agent/example_test.go Normal file

@@ -0,0 +1,107 @@
package agent_test
import (
"context"
"fmt"
"os"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/agent"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/tools"
)
// A researcher agent that can search the web and browse pages.
func Example_researcher() {
model := llm.OpenAI(os.Getenv("OPENAI_API_KEY")).Model("gpt-4o")
researcher := agent.New(model,
"You are a research assistant. Use web search to find information, "+
"then use the browser to read full articles when needed. "+
"Provide a concise summary of your findings.",
agent.WithTools(llm.NewToolBox(
tools.WebSearch(os.Getenv("BRAVE_API_KEY")),
tools.Browser(),
)),
agent.WithRequestOptions(llm.WithTemperature(0.3)),
)
result, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
if err != nil {
fmt.Println("Error:", err)
return
}
fmt.Println(result)
}
// A coder agent that can read, write, and execute code.
func Example_coder() {
model := llm.OpenAI(os.Getenv("OPENAI_API_KEY")).Model("gpt-4o")
coder := agent.New(model,
"You are a coding assistant. You can read files, write files, and execute commands. "+
"When asked to create a program, write the code to a file and then run it to verify it works.",
agent.WithTools(llm.NewToolBox(
tools.ReadFile(),
tools.WriteFile(),
tools.Exec(
tools.WithAllowedCommands([]string{"go", "python", "node", "cat", "ls"}),
tools.WithWorkDir(os.TempDir()),
),
)),
)
result, err := coder.Run(context.Background(),
"Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.")
if err != nil {
fmt.Println("Error:", err)
return
}
fmt.Println(result)
}
// An orchestrator agent that delegates to specialized sub-agents.
// The orchestrator breaks a complex task into subtasks and dispatches them
// to the appropriate sub-agent via tool calls.
func Example_orchestrator() {
model := llm.OpenAI(os.Getenv("OPENAI_API_KEY")).Model("gpt-4o")
// Specialized sub-agents
researcher := agent.New(model,
"You are a research assistant. Search the web for information on the given topic "+
"and return a concise summary.",
agent.WithTools(llm.NewToolBox(
tools.WebSearch(os.Getenv("BRAVE_API_KEY")),
)),
)
coder := agent.New(model,
"You are a coding assistant. Write and test code as requested. "+
"Save files and run them to verify they work.",
agent.WithTools(llm.NewToolBox(
tools.ReadFile(),
tools.WriteFile(),
tools.Exec(tools.WithAllowedCommands([]string{"go", "python"})),
)),
)
// Orchestrator can delegate to both sub-agents
orchestrator := agent.New(model,
"You are a project manager. Break complex tasks into research and coding subtasks. "+
"Use delegate_research for information gathering and delegate_coding for implementation. "+
"Synthesize the results into a final answer.",
agent.WithTools(llm.NewToolBox(
researcher.AsTool("delegate_research",
"Delegate a research task. Provide a clear question or topic to research."),
coder.AsTool("delegate_coding",
"Delegate a coding task. Provide clear requirements for what to implement."),
)),
)
result, err := orchestrator.Run(context.Background(),
"Research how to implement a binary search tree in Go, then create one with insert and search operations.")
if err != nil {
fmt.Println("Error:", err)
return
}
fmt.Println(result)
}

v2/anthropic/anthropic.go Normal file

@@ -0,0 +1,275 @@
// Package anthropic implements the go-llm v2 provider interface for Anthropic.
package anthropic
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/imageutil"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
anth "github.com/liushuangls/go-anthropic/v2"
)
// Provider implements the provider.Provider interface for Anthropic.
type Provider struct {
apiKey string
}
// New creates a new Anthropic provider.
func New(apiKey string) *Provider {
return &Provider{apiKey: apiKey}
}
// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
cl := anth.NewClient(p.apiKey)
anthReq := p.buildRequest(req)
resp, err := cl.CreateMessages(ctx, anthReq)
if err != nil {
return provider.Response{}, fmt.Errorf("anthropic completion error: %w", err)
}
return p.convertResponse(resp), nil
}
// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
cl := anth.NewClient(p.apiKey)
anthReq := p.buildRequest(req)
resp, err := cl.CreateMessagesStream(ctx, anth.MessagesStreamRequest{
MessagesRequest: anthReq,
OnContentBlockDelta: func(data anth.MessagesEventContentBlockDeltaData) {
if data.Delta.Type == "text_delta" && data.Delta.Text != nil {
events <- provider.StreamEvent{
Type: provider.StreamEventText,
Text: *data.Delta.Text,
}
}
},
})
if err != nil {
return fmt.Errorf("anthropic stream error: %w", err)
}
result := p.convertResponse(resp)
events <- provider.StreamEvent{
Type: provider.StreamEventDone,
Response: &result,
}
return nil
}
func (p *Provider) buildRequest(req provider.Request) anth.MessagesRequest {
anthReq := anth.MessagesRequest{
Model: anth.Model(req.Model),
MaxTokens: 4096,
}
if req.MaxTokens != nil {
anthReq.MaxTokens = *req.MaxTokens
}
var msgs []anth.Message
for _, msg := range req.Messages {
if msg.Role == "system" {
if len(anthReq.System) > 0 {
anthReq.System += "\n"
}
anthReq.System += msg.Content
continue
}
if msg.Role == "tool" {
// Tool results go back to Anthropic as a user message containing a tool_result block
toolUseID := msg.ToolCallID
content := msg.Content
isError := false
msgs = append(msgs, anth.Message{
Role: anth.RoleUser,
Content: []anth.MessageContent{
{
Type: anth.MessagesContentTypeToolResult,
MessageContentToolResult: &anth.MessageContentToolResult{
ToolUseID: &toolUseID,
Content: []anth.MessageContent{
{
Type: anth.MessagesContentTypeText,
Text: &content,
},
},
IsError: &isError,
},
},
},
})
continue
}
role := anth.RoleUser
if msg.Role == "assistant" {
role = anth.RoleAssistant
}
m := anth.Message{
Role: role,
Content: []anth.MessageContent{},
}
if msg.Content != "" {
m.Content = append(m.Content, anth.MessageContent{
Type: anth.MessagesContentTypeText,
Text: &msg.Content,
})
}
// Handle tool calls in assistant messages
for _, tc := range msg.ToolCalls {
var input json.RawMessage
if tc.Arguments != "" {
input = json.RawMessage(tc.Arguments)
} else {
input = json.RawMessage("{}")
}
m.Content = append(m.Content, anth.MessageContent{
Type: anth.MessagesContentTypeToolUse,
MessageContentToolUse: &anth.MessageContentToolUse{
ID: tc.ID,
Name: tc.Name,
Input: input,
},
})
}
// Handle images
for _, img := range msg.Images {
if role == anth.RoleAssistant {
role = anth.RoleUser
m.Role = anth.RoleUser
}
if img.Base64 != "" {
b64 := img.Base64
contentType := img.ContentType
// Compress if the decoded image is 5 MiB or larger
raw, err := base64.StdEncoding.DecodeString(b64)
if err == nil && len(raw) >= 5242880 {
compressed, mime, cerr := imageutil.CompressImage(b64, 5*1024*1024)
if cerr == nil {
b64 = compressed
contentType = mime
}
}
m.Content = append(m.Content, anth.NewImageMessageContent(
anth.NewMessageContentSource(
anth.MessagesContentSourceTypeBase64,
contentType,
b64,
)))
} else if img.URL != "" {
// Download and convert to base64 (Anthropic doesn't support URLs directly)
resp, err := http.Get(img.URL)
if err != nil {
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
continue
}
contentType := resp.Header.Get("Content-Type")
b64 := base64.StdEncoding.EncodeToString(data)
m.Content = append(m.Content, anth.NewImageMessageContent(
anth.NewMessageContentSource(
anth.MessagesContentSourceTypeBase64,
contentType,
b64,
)))
}
}
// Audio is not supported by Anthropic — skip silently.
// Merge consecutive same-role messages (Anthropic requires alternating)
if len(msgs) > 0 && msgs[len(msgs)-1].Role == role {
msgs[len(msgs)-1].Content = append(msgs[len(msgs)-1].Content, m.Content...)
} else {
msgs = append(msgs, m)
}
}
for _, tool := range req.Tools {
anthReq.Tools = append(anthReq.Tools, anth.ToolDefinition{
Name: tool.Name,
Description: tool.Description,
InputSchema: tool.Schema,
})
}
anthReq.Messages = msgs
if req.Temperature != nil {
f := float32(*req.Temperature)
anthReq.Temperature = &f
}
if req.TopP != nil {
f := float32(*req.TopP)
anthReq.TopP = &f
}
if len(req.Stop) > 0 {
anthReq.StopSequences = req.Stop
}
return anthReq
}
func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response {
var res provider.Response
var textParts []string
for _, block := range resp.Content {
switch block.Type {
case anth.MessagesContentTypeText:
if block.Text != nil {
textParts = append(textParts, *block.Text)
}
case anth.MessagesContentTypeToolUse:
if block.MessageContentToolUse != nil {
args, _ := json.Marshal(block.MessageContentToolUse.Input)
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
ID: block.MessageContentToolUse.ID,
Name: block.MessageContentToolUse.Name,
Arguments: string(args),
})
}
}
}
res.Text = strings.Join(textParts, "")
res.Usage = &provider.Usage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
return res
}

v2/chat.go Normal file

@@ -0,0 +1,153 @@
package llm
import (
"context"
"fmt"
)
// Chat manages a multi-turn conversation with automatic history tracking
// and optional automatic tool-call execution.
type Chat struct {
model *Model
messages []Message
tools *ToolBox
opts []RequestOption
}
// NewChat creates a new conversation with the given model.
func NewChat(model *Model, opts ...RequestOption) *Chat {
return &Chat{
model: model,
opts: opts,
}
}
// SetSystem sets or replaces the system message.
func (c *Chat) SetSystem(text string) {
filtered := make([]Message, 0, len(c.messages)+1)
for _, m := range c.messages {
if m.Role != RoleSystem {
filtered = append(filtered, m)
}
}
c.messages = append([]Message{SystemMessage(text)}, filtered...)
}
// SetTools configures the tools available for this chat.
func (c *Chat) SetTools(tb *ToolBox) {
c.tools = tb
}
// Send sends a user message and returns the assistant's text response.
// If the model calls tools, they are executed automatically and the loop
// continues until the model produces a text response (the "agent loop").
func (c *Chat) Send(ctx context.Context, text string) (string, error) {
return c.SendMessage(ctx, UserMessage(text))
}
// SendWithImages sends a user message with images attached.
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) {
return c.SendMessage(ctx, UserMessageWithImages(text, images...))
}
// SendMessage sends an arbitrary message and returns the final text response.
// Handles the full tool-call loop automatically.
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) {
c.messages = append(c.messages, msg)
opts := c.buildOpts()
for {
resp, err := c.model.Complete(ctx, c.messages, opts...)
if err != nil {
return "", fmt.Errorf("completion failed: %w", err)
}
c.messages = append(c.messages, resp.Message())
if !resp.HasToolCalls() {
return resp.Text, nil
}
if c.tools == nil {
return "", ErrNoToolsConfigured
}
toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
if err != nil {
return "", fmt.Errorf("tool execution failed: %w", err)
}
c.messages = append(c.messages, toolResults...)
}
}
// SendRaw sends a message and returns the raw Response without automatic tool execution.
// Useful when you want to handle tool calls manually.
func (c *Chat) SendRaw(ctx context.Context, msg Message) (Response, error) {
c.messages = append(c.messages, msg)
opts := c.buildOpts()
resp, err := c.model.Complete(ctx, c.messages, opts...)
if err != nil {
return Response{}, err
}
c.messages = append(c.messages, resp.Message())
return resp, nil
}
// SendStream sends a user message and returns a StreamReader for streaming responses.
func (c *Chat) SendStream(ctx context.Context, text string) (*StreamReader, error) {
c.messages = append(c.messages, UserMessage(text))
opts := c.buildOpts()
cfg := &requestConfig{}
for _, opt := range opts {
opt(cfg)
}
req := buildProviderRequest(c.model.model, c.messages, cfg)
return newStreamReader(ctx, c.model.provider, req)
}
// AddToolResults manually adds tool results to the conversation.
// Use with SendRaw when handling tool calls manually.
func (c *Chat) AddToolResults(results ...Message) {
c.messages = append(c.messages, results...)
}
// Messages returns the current conversation history (read-only copy).
func (c *Chat) Messages() []Message {
cp := make([]Message, len(c.messages))
copy(cp, c.messages)
return cp
}
// Reset clears the conversation history.
func (c *Chat) Reset() {
c.messages = nil
}
// Fork creates a copy of this chat with identical history, for branching conversations.
func (c *Chat) Fork() *Chat {
c2 := &Chat{
model: c.model,
messages: make([]Message, len(c.messages)),
tools: c.tools,
opts: c.opts,
}
copy(c2.messages, c.messages)
return c2
}
func (c *Chat) buildOpts() []RequestOption {
opts := make([]RequestOption, len(c.opts))
copy(opts, c.opts)
if c.tools != nil {
opts = append(opts, WithTools(c.tools))
}
return opts
}

v2/chat_test.go Normal file

@@ -0,0 +1,407 @@
package llm
import (
"context"
"errors"
"sync/atomic"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestChat_Send(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "Hello there!"})
model := newMockModel(mp)
chat := NewChat(model)
text, err := chat.Send(context.Background(), "Hi")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "Hello there!" {
t.Errorf("expected 'Hello there!', got %q", text)
}
}
func TestChat_SendMessage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "reply"})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
msgs := chat.Messages()
if len(msgs) != 2 {
t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs))
}
if msgs[0].Role != RoleUser {
t.Errorf("expected first message role=user, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "msg1" {
t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text)
}
if msgs[1].Role != RoleAssistant {
t.Errorf("expected second message role=assistant, got %v", msgs[1].Role)
}
if msgs[1].Content.Text != "reply" {
t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text)
}
}
func TestChat_SetSystem(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
chat.SetSystem("You are a bot")
msgs := chat.Messages()
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Role != RoleSystem {
t.Errorf("expected role=system, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "You are a bot" {
t.Errorf("expected system text, got %q", msgs[0].Content.Text)
}
// Replace system message
chat.SetSystem("You are a helpful bot")
msgs = chat.Messages()
if len(msgs) != 1 {
t.Fatalf("expected 1 message after replace, got %d", len(msgs))
}
if msgs[0].Content.Text != "You are a helpful bot" {
t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text)
}
// System message stays first even after adding other messages
_, _ = chat.Send(context.Background(), "Hi")
chat.SetSystem("New system")
msgs = chat.Messages()
if msgs[0].Role != RoleSystem {
t.Errorf("expected system as first message, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "New system" {
t.Errorf("expected 'New system', got %q", msgs[0].Content.Text)
}
}
func TestChat_ToolCallLoop(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
// First call: request a tool
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "greet", Arguments: "{}"},
},
}, nil
}
// Second call: return text
return provider.Response{Text: "done"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello!", nil
})
chat.SetTools(NewToolBox(tool))
text, err := chat.Send(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "done" {
t.Errorf("expected 'done', got %q", text)
}
if atomic.LoadInt32(&callCount) != 2 {
t.Errorf("expected 2 provider calls, got %d", callCount)
}
// Check message history: user, assistant (tool call), tool result, assistant (text)
msgs := chat.Messages()
if len(msgs) != 4 {
t.Fatalf("expected 4 messages, got %d", len(msgs))
}
if msgs[0].Role != RoleUser {
t.Errorf("msg[0]: expected user, got %v", msgs[0].Role)
}
if msgs[1].Role != RoleAssistant {
t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role)
}
if len(msgs[1].ToolCalls) != 1 {
t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls))
}
if msgs[2].Role != RoleTool {
t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role)
}
if msgs[2].Content.Text != "hello!" {
t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text)
}
if msgs[3].Role != RoleAssistant {
t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role)
}
}
func TestChat_ToolCallLoop_NoTools(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "fake", Arguments: "{}"},
},
})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.Send(context.Background(), "test")
if !errors.Is(err, ErrNoToolsConfigured) {
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
}
}
func TestChat_SendRaw(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "raw response",
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "tool1", Arguments: `{"x":1}`},
},
})
model := newMockModel(mp)
chat := NewChat(model)
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "raw response" {
t.Errorf("expected 'raw response', got %q", resp.Text)
}
if !resp.HasToolCalls() {
t.Error("expected HasToolCalls() to be true")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "tool1" {
t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name)
}
}
func TestChat_SendRaw_ManualToolResults(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "tool1", Arguments: "{}"},
},
}, nil
}
return provider.Response{Text: "final"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
// First call returns tool calls
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !resp.HasToolCalls() {
t.Fatal("expected tool calls")
}
// Manually add tool result
chat.AddToolResults(ToolResultMessage("tc1", "tool result"))
// Second call returns text
resp, err = chat.SendRaw(context.Background(), UserMessage("continue"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "final" {
t.Errorf("expected 'final', got %q", resp.Text)
}
// Check the full history
msgs := chat.Messages()
// user, assistant(tool call), tool result, user, assistant(text)
if len(msgs) != 5 {
t.Fatalf("expected 5 messages, got %d", len(msgs))
}
if msgs[2].Role != RoleTool {
t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role)
}
if msgs[2].ToolCallID != "tc1" {
t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID)
}
}
func TestChat_Messages(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
_, _ = chat.Send(context.Background(), "test")
msgs := chat.Messages()
// Verify it's a copy — modifying returned slice shouldn't affect chat
msgs[0] = Message{}
original := chat.Messages()
if original[0].Role != RoleUser {
t.Error("Messages() did not return a copy")
}
}
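TestChat_Messages above checks that `Messages()` hands back a defensive copy. A minimal sketch of how such a copying accessor is typically written (the `chat`/`message` types here are stand-ins, not the library's):

```go
package main

import "fmt"

type message struct{ Role, Text string }

type chat struct{ history []message }

// Messages returns a copy of the history so callers cannot
// mutate the chat's internal state through the returned slice.
func (c *chat) Messages() []message {
	out := make([]message, len(c.history))
	copy(out, c.history)
	return out
}

func main() {
	c := &chat{history: []message{{Role: "user", Text: "hi"}}}
	msgs := c.Messages()
	msgs[0] = message{} // mutating the copy...
	fmt.Println(c.Messages()[0].Role) // ...leaves the original intact: user
}
```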
func TestChat_Reset(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
_, _ = chat.Send(context.Background(), "test")
if len(chat.Messages()) == 0 {
t.Fatal("expected messages before reset")
}
chat.Reset()
if len(chat.Messages()) != 0 {
t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages()))
}
}
func TestChat_Fork(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
_, _ = chat.Send(context.Background(), "msg1")
fork := chat.Fork()
// Fork should have same history
if len(fork.Messages()) != len(chat.Messages()) {
t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages()))
}
// Adding to fork should not affect original
_, _ = fork.Send(context.Background(), "msg2")
if len(fork.Messages()) == len(chat.Messages()) {
t.Error("fork messages should be independent of original")
}
// Adding to original should not affect fork
forkLen := len(fork.Messages())
originalLen := len(chat.Messages())
_, _ = chat.Send(context.Background(), "msg3")
if len(chat.Messages()) == originalLen {
t.Error("original should have more messages after send")
}
if len(fork.Messages()) != forkLen {
t.Error("fork should be unaffected by sends to original")
}
}
}
func TestChat_SendWithImages(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "I see an image"})
model := newMockModel(mp)
chat := NewChat(model)
img := Image{URL: "https://example.com/image.png"}
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "I see an image" {
t.Errorf("expected 'I see an image', got %q", text)
}
// Verify the image was passed through to the provider
req := mp.lastRequest()
if len(req.Messages) == 0 {
t.Fatal("expected messages in request")
}
lastUserMsg := req.Messages[0]
if len(lastUserMsg.Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images))
}
if lastUserMsg.Images[0].URL != "https://example.com/image.png" {
t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL)
}
}
func TestChat_MultipleToolCallRounds(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n <= 3 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"},
},
}, nil
}
return provider.Response{Text: "all done"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
var execCount int32
tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) {
atomic.AddInt32(&execCount, 1)
return "counted", nil
})
chat.SetTools(NewToolBox(tool))
text, err := chat.Send(context.Background(), "count three times")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "all done" {
t.Errorf("expected 'all done', got %q", text)
}
if atomic.LoadInt32(&callCount) != 4 {
t.Errorf("expected 4 provider calls, got %d", callCount)
}
if atomic.LoadInt32(&execCount) != 3 {
t.Errorf("expected 3 tool executions, got %d", execCount)
}
}
func TestChat_SendError(t *testing.T) {
wantErr := errors.New("provider failed")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, wantErr
})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.Send(context.Background(), "test")
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, wantErr) {
t.Errorf("expected wrapped provider error, got %v", err)
}
}
func TestChat_WithRequestOptions(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
_, err := chat.Send(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if req.Temperature == nil || *req.Temperature != 0.5 {
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != 200 {
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
}
}
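The tool-call tests above all exercise the same internal loop: call the provider, execute any requested tools, append their results, and repeat until the model answers with plain text. A self-contained sketch of that loop, using simplified stand-in types (`callModel`, `runTool`, `message` are illustrative, not the library's API):

```go
package main

import "fmt"

type toolCall struct{ ID, Name, Args string }

type response struct {
	Text      string
	ToolCalls []toolCall
}

type message struct{ Role, Text string }

// runToolLoop keeps calling the model until it returns plain text,
// executing requested tools in between and appending their results
// to the history (tool-call details on the assistant message omitted).
func runToolLoop(callModel func([]message) response, runTool func(toolCall) string, history []message) (string, []message) {
	for {
		resp := callModel(history)
		if len(resp.ToolCalls) == 0 {
			history = append(history, message{Role: "assistant", Text: resp.Text})
			return resp.Text, history
		}
		history = append(history, message{Role: "assistant"})
		for _, tc := range resp.ToolCalls {
			history = append(history, message{Role: "tool", Text: runTool(tc)})
		}
	}
}

func main() {
	calls := 0
	callModel := func(h []message) response {
		calls++
		if calls == 1 {
			return response{ToolCalls: []toolCall{{ID: "tc1", Name: "greet"}}}
		}
		return response{Text: "done"}
	}
	text, history := runToolLoop(callModel, func(tc toolCall) string { return "hello!" }, []message{{Role: "user", Text: "test"}})
	fmt.Println(text, len(history)) // done 4
}
```

This mirrors the history shape TestChat_ToolCallLoop asserts: user, assistant (tool call), tool result, assistant (text).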

v2/constructors.go Normal file

@@ -0,0 +1,48 @@
package llm
import (
anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic"
googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google"
openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
)
// OpenAI creates an OpenAI client.
//
// Example:
//
// model := llm.OpenAI("sk-...").Model("gpt-4o")
func OpenAI(apiKey string, opts ...ClientOption) *Client {
cfg := &clientConfig{}
for _, opt := range opts {
opt(cfg)
}
return NewClient(openaiProvider.New(apiKey, cfg.baseURL))
}
// Anthropic creates an Anthropic client.
//
// Example:
//
// model := llm.Anthropic("sk-ant-...").Model("claude-sonnet-4-20250514")
func Anthropic(apiKey string, opts ...ClientOption) *Client {
cfg := &clientConfig{}
for _, opt := range opts {
opt(cfg)
}
_ = cfg // Anthropic doesn't support custom base URL in the SDK
return NewClient(anthProvider.New(apiKey))
}
// Google creates a Google (Gemini) client.
//
// Example:
//
// model := llm.Google("...").Model("gemini-2.0-flash")
func Google(apiKey string, opts ...ClientOption) *Client {
cfg := &clientConfig{}
for _, opt := range opts {
opt(cfg)
}
_ = cfg // Google doesn't support custom base URL in the SDK
return NewClient(googleProvider.New(apiKey))
}
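All three constructors apply functional options to a `clientConfig` before building the client. A minimal, self-contained sketch of that pattern (`WithBaseURL` and `newConfig` are hypothetical names for illustration):

```go
package main

import "fmt"

type clientConfig struct{ baseURL string }

// ClientOption mutates the config; constructors fold every
// supplied option into a single config struct.
type ClientOption func(*clientConfig)

// WithBaseURL is a hypothetical option showing how cfg.baseURL
// would be populated before the provider is constructed.
func WithBaseURL(u string) ClientOption {
	return func(c *clientConfig) { c.baseURL = u }
}

func newConfig(opts ...ClientOption) *clientConfig {
	cfg := &clientConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}

func main() {
	cfg := newConfig(WithBaseURL("https://proxy.example.com/v1"))
	fmt.Println(cfg.baseURL)
}
```

The pattern keeps the constructor signatures stable as new options are added, which is why Anthropic and Google can accept (and currently ignore) the same option list.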

v2/errors.go Normal file

@@ -0,0 +1,20 @@
package llm
import "errors"
var (
// ErrNoToolsConfigured is returned when the model requests tool calls but no tools are available.
ErrNoToolsConfigured = errors.New("model requested tool calls but no tools configured")
// ErrToolNotFound is returned when a requested tool is not in the toolbox.
ErrToolNotFound = errors.New("tool not found")
// ErrNotConnected is returned when trying to use an MCP server that isn't connected.
ErrNotConnected = errors.New("MCP server not connected")
// ErrStreamClosed is returned when trying to read from a closed stream.
ErrStreamClosed = errors.New("stream closed")
// ErrNoStructuredOutput is returned when the model did not return a structured output tool call.
ErrNoStructuredOutput = errors.New("model did not return structured output")
)

v2/generate.go Normal file

@@ -0,0 +1,54 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
)
const structuredOutputToolName = "structured_output"
// Generate sends a single user prompt to the model and parses the response into T.
// T must be a struct. The model is forced to return structured output matching T's schema
// by using a hidden tool call internally.
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) {
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
}
// GenerateWith sends the given messages to the model and parses the response into T.
// T must be a struct. The model is forced to return structured output matching T's schema
// by using a hidden tool call internally.
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) {
var zero T
s := schema.FromStruct(zero)
tool := Tool{
Name: structuredOutputToolName,
Description: "Return your response as structured data using this function. You MUST call this function with your response.",
Schema: s,
}
// Append WithTools as the last option so it overrides any user-provided tools.
opts = append(opts, WithTools(NewToolBox(tool)))
resp, err := model.Complete(ctx, messages, opts...)
if err != nil {
return zero, err
}
// Find the structured_output tool call in the response.
for _, tc := range resp.ToolCalls {
if tc.Name == structuredOutputToolName {
var result T
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
return zero, fmt.Errorf("failed to parse structured output: %w", err)
}
return result, nil
}
}
return zero, ErrNoStructuredOutput
}
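The heart of GenerateWith is decoding the tool call's JSON `Arguments` string into the caller's type. That step, isolated into a runnable sketch (the `parseArguments` helper and `person` type are stand-ins for the generic machinery above):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type person struct {
	Name string `json:"name"`
	Age  int    `json:"age"`
}

// parseArguments mirrors the json.Unmarshal step in GenerateWith:
// tool-call arguments arrive as a JSON string and decode into T.
func parseArguments[T any](args string) (T, error) {
	var result T
	if err := json.Unmarshal([]byte(args), &result); err != nil {
		var zero T
		return zero, fmt.Errorf("failed to parse structured output: %w", err)
	}
	return result, nil
}

func main() {
	p, err := parseArguments[person](`{"name":"Alice","age":30}`)
	fmt.Println(p.Name, p.Age, err) // Alice 30 <nil>
}
```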

v2/generate_test.go Normal file

@@ -0,0 +1,241 @@
package llm
import (
"context"
"errors"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
type testPerson struct {
Name string `json:"name" description:"The person's name"`
Age int `json:"age" description:"The person's age"`
}
func TestGenerate(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{"name":"Alice","age":30}`,
},
},
})
model := newMockModel(mp)
result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Name != "Alice" {
t.Errorf("expected name 'Alice', got %q", result.Name)
}
if result.Age != 30 {
t.Errorf("expected age 30, got %d", result.Age)
}
// Verify the tool was sent in the request
req := mp.lastRequest()
if len(req.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
}
if req.Tools[0].Name != "structured_output" {
t.Errorf("expected tool name 'structured_output', got %q", req.Tools[0].Name)
}
}
func TestGenerateWith(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{"name":"Bob","age":25}`,
},
},
})
model := newMockModel(mp)
messages := []Message{
SystemMessage("You are helpful."),
UserMessage("Tell me about Bob"),
}
result, err := GenerateWith[testPerson](context.Background(), model, messages)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Name != "Bob" {
t.Errorf("expected name 'Bob', got %q", result.Name)
}
if result.Age != 25 {
t.Errorf("expected age 25, got %d", result.Age)
}
// Verify messages were passed through
req := mp.lastRequest()
if len(req.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
}
if req.Messages[0].Role != "system" {
t.Errorf("expected first message role 'system', got %q", req.Messages[0].Role)
}
}
func TestGenerate_NoToolCall(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "I can't use tools right now.",
})
model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, ErrNoStructuredOutput) {
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
}
}
func TestGenerate_InvalidJSON(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{not valid json}`,
},
},
})
model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
if err == nil {
t.Fatal("expected error, got nil")
}
if errors.Is(err, ErrNoStructuredOutput) {
t.Error("expected parse error, not ErrNoStructuredOutput")
}
}
type testAddress struct {
Street string `json:"street" description:"Street address"`
City string `json:"city" description:"City name"`
}
type testPersonWithAddress struct {
Name string `json:"name" description:"The person's name"`
Age int `json:"age" description:"The person's age"`
Address testAddress `json:"address" description:"The person's address"`
}
func TestGenerate_NestedStruct(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{"name":"Carol","age":40,"address":{"street":"123 Main St","city":"Springfield"}}`,
},
},
})
model := newMockModel(mp)
result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Name != "Carol" {
t.Errorf("expected name 'Carol', got %q", result.Name)
}
if result.Address.Street != "123 Main St" {
t.Errorf("expected street '123 Main St', got %q", result.Address.Street)
}
if result.Address.City != "Springfield" {
t.Errorf("expected city 'Springfield', got %q", result.Address.City)
}
}
func TestGenerate_WithOptions(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{"name":"Dave","age":35}`,
},
},
})
model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
WithTemperature(0.5),
WithMaxTokens(200),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if req.Temperature == nil || *req.Temperature != 0.5 {
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != 200 {
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
}
}
func TestGenerate_WithMiddleware(t *testing.T) {
var middlewareCalled bool
mw := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
middlewareCalled = true
return next(ctx, model, messages, cfg)
}
}
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{"name":"Eve","age":28}`,
},
},
})
model := newMockModel(mp).WithMiddleware(mw)
result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !middlewareCalled {
t.Error("middleware was not called")
}
if result.Name != "Eve" {
t.Errorf("expected name 'Eve', got %q", result.Name)
}
}
func TestGenerate_WrongToolName(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "some_other_tool",
Arguments: `{"name":"Frank","age":50}`,
},
},
})
model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, ErrNoStructuredOutput) {
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
}
}

v2/go.mod Normal file

@@ -0,0 +1,41 @@
module gitea.stevedudenhoeffer.com/steve/go-llm/v2
go 1.24.0
toolchain go1.24.2
require (
github.com/liushuangls/go-anthropic/v2 v2.17.0
github.com/modelcontextprotocol/go-sdk v1.2.0
github.com/openai/openai-go v1.12.0
github.com/pkg/sftp v1.13.10
golang.org/x/crypto v0.41.0
golang.org/x/image v0.35.0
google.golang.org/genai v1.45.0
)
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.9.3 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/net v0.42.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc v1.66.2 // indirect
google.golang.org/protobuf v1.34.2 // indirect
)

v2/go.sum Normal file

@@ -0,0 +1,165 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U=
cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genai v1.45.0 h1:s80ZpS42XW0zu/ogiOtenCio17nJ7reEFJjoCftukpA=
google.golang.org/genai v1.45.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo=
google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

355
v2/google/google.go Normal file

@@ -0,0 +1,355 @@
// Package google implements the go-llm v2 provider interface for Google (Gemini).
package google
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
"google.golang.org/genai"
)
// Provider implements the provider.Provider interface for Google Gemini.
type Provider struct {
apiKey string
}
// New creates a new Google provider.
func New(apiKey string) *Provider {
return &Provider{apiKey: apiKey}
}
// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: p.apiKey,
Backend: genai.BackendGeminiAPI,
})
if err != nil {
return provider.Response{}, fmt.Errorf("google client error: %w", err)
}
contents, cfg := p.buildRequest(req)
resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg)
if err != nil {
return provider.Response{}, fmt.Errorf("google completion error: %w", err)
}
return p.convertResponse(resp)
}
// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: p.apiKey,
Backend: genai.BackendGeminiAPI,
})
if err != nil {
return fmt.Errorf("google client error: %w", err)
}
contents, cfg := p.buildRequest(req)
var fullText strings.Builder
var toolCalls []provider.ToolCall
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
if err != nil {
return fmt.Errorf("google stream error: %w", err)
}
for _, c := range resp.Candidates {
if c.Content == nil {
continue
}
for _, part := range c.Content.Parts {
if part.Text != "" {
fullText.WriteString(part.Text)
events <- provider.StreamEvent{
Type: provider.StreamEventText,
Text: part.Text,
}
}
if part.FunctionCall != nil {
args, _ := json.Marshal(part.FunctionCall.Args)
tc := provider.ToolCall{
ID: part.FunctionCall.Name,
Name: part.FunctionCall.Name,
Arguments: string(args),
}
toolCalls = append(toolCalls, tc)
events <- provider.StreamEvent{
Type: provider.StreamEventToolStart,
ToolCall: &tc,
ToolIndex: len(toolCalls) - 1,
}
events <- provider.StreamEvent{
Type: provider.StreamEventToolEnd,
ToolCall: &tc,
ToolIndex: len(toolCalls) - 1,
}
}
}
}
}
events <- provider.StreamEvent{
Type: provider.StreamEventDone,
Response: &provider.Response{
Text: fullText.String(),
ToolCalls: toolCalls,
},
}
return nil
}
func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) {
var contents []*genai.Content
cfg := &genai.GenerateContentConfig{}
for _, tool := range req.Tools {
cfg.Tools = append(cfg.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: schemaToGenai(tool.Schema),
},
},
})
}
if req.Temperature != nil {
f := float32(*req.Temperature)
cfg.Temperature = &f
}
if req.MaxTokens != nil {
cfg.MaxOutputTokens = int32(*req.MaxTokens)
}
if req.TopP != nil {
f := float32(*req.TopP)
cfg.TopP = &f
}
if len(req.Stop) > 0 {
cfg.StopSequences = req.Stop
}
for _, msg := range req.Messages {
var role genai.Role
switch msg.Role {
case "system":
cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser)
continue
case "assistant":
role = genai.RoleModel
case "tool":
// Tool results go back as function responses; the genai SDK expects them
// under RoleUser. ToolCallID carries the function name here, since Gemini
// has no separate call IDs (convertResponse sets ID = Name).
contents = append(contents, &genai.Content{
Role: genai.RoleUser,
Parts: []*genai.Part{
{
FunctionResponse: &genai.FunctionResponse{
Name: msg.ToolCallID,
Response: map[string]any{
"result": msg.Content,
},
},
},
},
})
continue
default:
role = genai.RoleUser
}
var parts []*genai.Part
if msg.Content != "" {
parts = append(parts, genai.NewPartFromText(msg.Content))
}
// Handle tool calls in assistant messages
for _, tc := range msg.ToolCalls {
var args map[string]any
if tc.Arguments != "" {
_ = json.Unmarshal([]byte(tc.Arguments), &args)
}
parts = append(parts, &genai.Part{
FunctionCall: &genai.FunctionCall{
Name: tc.Name,
Args: args,
},
})
}
for _, img := range msg.Images {
if img.URL != "" {
// Gemini doesn't support URLs directly; download
resp, err := http.Get(img.URL)
if err != nil {
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
continue
}
mimeType := http.DetectContentType(data)
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
} else if img.Base64 != "" {
data, err := base64.StdEncoding.DecodeString(img.Base64)
if err != nil {
continue
}
parts = append(parts, genai.NewPartFromBytes(data, img.ContentType))
}
}
for _, aud := range msg.Audio {
if aud.URL != "" {
resp, err := http.Get(aud.URL)
if err != nil {
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
continue
}
mimeType := resp.Header.Get("Content-Type")
if mimeType == "" {
mimeType = aud.ContentType
}
if mimeType == "" {
mimeType = "audio/wav"
}
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
} else if aud.Base64 != "" {
data, err := base64.StdEncoding.DecodeString(aud.Base64)
if err != nil {
continue
}
ct := aud.ContentType
if ct == "" {
ct = "audio/wav"
}
parts = append(parts, genai.NewPartFromBytes(data, ct))
}
}
contents = append(contents, genai.NewContentFromParts(parts, role))
}
return contents, cfg
}
func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) {
var res provider.Response
for _, c := range resp.Candidates {
if c.Content == nil {
continue
}
for _, part := range c.Content.Parts {
if part.Text != "" {
res.Text += part.Text
}
if part.FunctionCall != nil {
args, _ := json.Marshal(part.FunctionCall.Args)
// Gemini does not return tool-call IDs, so the function name doubles as the ID.
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
ID: part.FunctionCall.Name,
Name: part.FunctionCall.Name,
Arguments: string(args),
})
}
}
}
if resp.UsageMetadata != nil {
res.Usage = &provider.Usage{
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
}
}
return res, nil
}
// schemaToGenai converts a JSON Schema map to a genai.Schema.
func schemaToGenai(s map[string]any) *genai.Schema {
if s == nil {
return nil
}
schema := &genai.Schema{}
if t, ok := s["type"].(string); ok {
switch t {
case "object":
schema.Type = genai.TypeObject
case "array":
schema.Type = genai.TypeArray
case "string":
schema.Type = genai.TypeString
case "integer":
schema.Type = genai.TypeInteger
case "number":
schema.Type = genai.TypeNumber
case "boolean":
schema.Type = genai.TypeBoolean
}
}
if desc, ok := s["description"].(string); ok {
schema.Description = desc
}
if props, ok := s["properties"].(map[string]any); ok {
schema.Properties = make(map[string]*genai.Schema)
for k, v := range props {
if vm, ok := v.(map[string]any); ok {
schema.Properties[k] = schemaToGenai(vm)
}
}
}
if req, ok := s["required"].([]string); ok {
schema.Required = req
} else if req, ok := s["required"].([]any); ok {
for _, r := range req {
if rs, ok := r.(string); ok {
schema.Required = append(schema.Required, rs)
}
}
}
if items, ok := s["items"].(map[string]any); ok {
schema.Items = schemaToGenai(items)
}
if enums, ok := s["enum"].([]string); ok {
schema.Enum = enums
} else if enums, ok := s["enum"].([]any); ok {
for _, e := range enums {
if es, ok := e.(string); ok {
schema.Enum = append(schema.Enum, es)
}
}
}
return schema
}


@@ -0,0 +1,105 @@
// Package imageutil provides image compression utilities.
package imageutil
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/gif"
"image/jpeg"
_ "image/png" // register PNG decoder
"net/http"
"golang.org/x/image/draw"
)
// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns
// a base-64-encoded version whose decoded size is at most maxLength bytes,
// along with the resulting MIME type.
func CompressImage(b64 string, maxLength int) (string, string, error) {
raw, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return "", "", fmt.Errorf("base64 decode: %w", err)
}
mime := http.DetectContentType(raw)
if len(raw) <= maxLength {
return b64, mime, nil
}
switch mime {
case "image/gif":
return compressGIF(raw, maxLength)
default:
return compressRaster(raw, maxLength)
}
}
func compressRaster(src []byte, maxLength int) (string, string, error) {
img, _, err := image.Decode(bytes.NewReader(src))
if err != nil {
return "", "", fmt.Errorf("decode raster: %w", err)
}
quality := 95
for {
var buf bytes.Buffer
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil {
return "", "", fmt.Errorf("jpeg encode: %w", err)
}
if buf.Len() <= maxLength {
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/jpeg", nil
}
if quality > 20 {
quality -= 5
continue
}
b := img.Bounds()
if b.Dx() < 100 || b.Dy() < 100 {
return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
}
dst := image.NewRGBA(image.Rect(0, 0, int(float64(b.Dx())*0.8), int(float64(b.Dy())*0.8)))
draw.ApproxBiLinear.Scale(dst, dst.Bounds(), img, b, draw.Over, nil)
img = dst
quality = 95
}
}
func compressGIF(src []byte, maxLength int) (string, string, error) {
g, err := gif.DecodeAll(bytes.NewReader(src))
if err != nil {
return "", "", fmt.Errorf("gif decode: %w", err)
}
for {
var buf bytes.Buffer
if err := gif.EncodeAll(&buf, g); err != nil {
return "", "", fmt.Errorf("gif encode: %w", err)
}
if buf.Len() <= maxLength {
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
}
w, h := g.Config.Width, g.Config.Height
if w < 100 || h < 100 {
return "", "", fmt.Errorf("cannot compress animated GIF below %.02fMiB", float64(maxLength)/1048576.0)
}
nw, nh := int(float64(w)*0.8), int(float64(h)*0.8)
for i, frm := range g.Image {
rgba := image.NewRGBA(frm.Bounds())
draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src)
dst := image.NewRGBA(image.Rect(0, 0, nw, nh))
draw.ApproxBiLinear.Scale(dst, dst.Bounds(), rgba, rgba.Bounds(), draw.Over, nil)
// Re-quantize against the frame's original palette; a nil palette would
// make the subsequent gif.EncodeAll fail with an empty-palette error.
paletted := image.NewPaletted(dst.Bounds(), frm.Palette)
draw.FloydSteinberg.Draw(paletted, paletted.Bounds(), dst, dst.Bounds().Min)
g.Image[i] = paletted
}
g.Config.Width, g.Config.Height = nw, nh
}
}


@@ -0,0 +1,188 @@
// Package schema provides JSON Schema generation from Go structs.
// It produces standard JSON Schema as map[string]any, with no provider-specific types.
package schema
import (
"reflect"
"strings"
)
// FromStruct generates a JSON Schema object from a Go struct.
// Struct tags used:
// - `json:"name"` — sets the property name (standard Go JSON convention)
// - `description:"..."` — sets the property description
// - `enum:"a,b,c"` — restricts string values to the given set
//
// Pointer fields are treated as optional; non-pointer fields are required.
// Anonymous (embedded) struct fields are flattened into the parent.
func FromStruct(v any) map[string]any {
t := reflect.TypeOf(v)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
panic("schema.FromStruct expects a struct or pointer to struct")
}
return objectSchema(t)
}
func objectSchema(t reflect.Type) map[string]any {
properties := map[string]any{}
var required []string
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Flatten anonymous (embedded) structs
if field.Anonymous {
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if ft.Kind() == reflect.Struct {
embedded := objectSchema(ft)
if props, ok := embedded["properties"].(map[string]any); ok {
for k, v := range props {
properties[k] = v
}
}
if req, ok := embedded["required"].([]string); ok {
required = append(required, req...)
}
}
continue
}
name := fieldName(field)
isRequired := true
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
isRequired = false
}
prop := fieldSchema(field, ft)
properties[name] = prop
if isRequired {
required = append(required, name)
}
}
result := map[string]any{
"type": "object",
"properties": properties,
}
if len(required) > 0 {
result["required"] = required
}
return result
}
func fieldSchema(field reflect.StructField, ft reflect.Type) map[string]any {
prop := map[string]any{}
// Check for enum tag first
if enumTag, ok := field.Tag.Lookup("enum"); ok {
vals := parseEnum(enumTag)
prop["type"] = "string"
prop["enum"] = vals
if desc, ok := field.Tag.Lookup("description"); ok {
prop["description"] = desc
}
return prop
}
switch ft.Kind() {
case reflect.String:
prop["type"] = "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
prop["type"] = "integer"
case reflect.Float32, reflect.Float64:
prop["type"] = "number"
case reflect.Bool:
prop["type"] = "boolean"
case reflect.Struct:
// Use the nested object's schema as the property so that a description
// tag on the field (attached below) is not lost.
prop = objectSchema(ft)
case reflect.Slice:
prop["type"] = "array"
elemType := ft.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
prop["items"] = typeSchema(elemType)
case reflect.Map:
prop["type"] = "object"
if ft.Key().Kind() == reflect.String {
valType := ft.Elem()
if valType.Kind() == reflect.Ptr {
valType = valType.Elem()
}
prop["additionalProperties"] = typeSchema(valType)
}
default:
prop["type"] = "string" // fallback
}
if desc, ok := field.Tag.Lookup("description"); ok {
prop["description"] = desc
}
return prop
}
func typeSchema(t reflect.Type) map[string]any {
switch t.Kind() {
case reflect.String:
return map[string]any{"type": "string"}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return map[string]any{"type": "integer"}
case reflect.Float32, reflect.Float64:
return map[string]any{"type": "number"}
case reflect.Bool:
return map[string]any{"type": "boolean"}
case reflect.Struct:
return objectSchema(t)
case reflect.Slice:
elemType := t.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
return map[string]any{
"type": "array",
"items": typeSchema(elemType),
}
default:
return map[string]any{"type": "string"}
}
}
func fieldName(f reflect.StructField) string {
if tag, ok := f.Tag.Lookup("json"); ok {
parts := strings.SplitN(tag, ",", 2)
if parts[0] != "" && parts[0] != "-" {
return parts[0]
}
}
return f.Name
}
func parseEnum(tag string) []string {
parts := strings.Split(tag, ",")
var vals []string
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
vals = append(vals, p)
}
}
return vals
}
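As a concrete illustration of the tag conventions above: a hypothetical struct with a required field Name string (json tag "name", description tag "The name") and an optional pointer field Nick *string (json tag "nick") would yield a schema equivalent to the following, map ordering aside:

```json
{
  "type": "object",
  "properties": {
    "name": {"type": "string", "description": "The name"},
    "nick": {"type": "string"}
  },
  "required": ["name"]
}
```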


@@ -0,0 +1,181 @@
package schema
import (
"encoding/json"
"testing"
)
type SimpleParams struct {
Name string `json:"name" description:"The name"`
Age int `json:"age" description:"The age"`
}
type OptionalParams struct {
Required string `json:"required" description:"A required field"`
Optional *string `json:"optional,omitempty" description:"An optional field"`
}
type EnumParams struct {
Color string `json:"color" description:"The color" enum:"red,green,blue"`
}
type NestedParams struct {
Inner SimpleParams `json:"inner" description:"Nested object"`
}
type ArrayParams struct {
Items []string `json:"items" description:"A list of items"`
}
type EmbeddedBase struct {
ID string `json:"id" description:"The ID"`
}
type EmbeddedParams struct {
EmbeddedBase
Name string `json:"name" description:"The name"`
}
func TestFromStruct_Simple(t *testing.T) {
s := FromStruct(SimpleParams{})
if s["type"] != "object" {
t.Errorf("expected type=object, got %v", s["type"])
}
props, ok := s["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties to be map[string]any")
}
if len(props) != 2 {
t.Errorf("expected 2 properties, got %d", len(props))
}
nameSchema, ok := props["name"].(map[string]any)
if !ok {
t.Fatal("expected name property to be map[string]any")
}
if nameSchema["type"] != "string" {
t.Errorf("expected name type=string, got %v", nameSchema["type"])
}
if nameSchema["description"] != "The name" {
t.Errorf("expected name description='The name', got %v", nameSchema["description"])
}
ageSchema, ok := props["age"].(map[string]any)
if !ok {
t.Fatal("expected age property to be map[string]any")
}
if ageSchema["type"] != "integer" {
t.Errorf("expected age type=integer, got %v", ageSchema["type"])
}
required, ok := s["required"].([]string)
if !ok {
t.Fatal("expected required to be []string")
}
if len(required) != 2 {
t.Errorf("expected 2 required fields, got %d", len(required))
}
}
func TestFromStruct_Optional(t *testing.T) {
s := FromStruct(OptionalParams{})
required, ok := s["required"].([]string)
if !ok {
t.Fatal("expected required to be []string")
}
// Only "required" field should be required, not "optional"
if len(required) != 1 {
t.Errorf("expected 1 required field, got %d: %v", len(required), required)
}
if required[0] != "required" {
t.Errorf("expected required field 'required', got %v", required[0])
}
}
func TestFromStruct_Enum(t *testing.T) {
s := FromStruct(EnumParams{})
props := s["properties"].(map[string]any)
colorSchema := props["color"].(map[string]any)
if colorSchema["type"] != "string" {
t.Errorf("expected enum type=string, got %v", colorSchema["type"])
}
enums, ok := colorSchema["enum"].([]string)
if !ok {
t.Fatal("expected enum to be []string")
}
if len(enums) != 3 {
t.Errorf("expected 3 enum values, got %d", len(enums))
}
}
func TestFromStruct_Nested(t *testing.T) {
s := FromStruct(NestedParams{})
props := s["properties"].(map[string]any)
innerSchema := props["inner"].(map[string]any)
if innerSchema["type"] != "object" {
t.Errorf("expected nested type=object, got %v", innerSchema["type"])
}
innerProps := innerSchema["properties"].(map[string]any)
if len(innerProps) != 2 {
t.Errorf("expected 2 inner properties, got %d", len(innerProps))
}
}
func TestFromStruct_Array(t *testing.T) {
s := FromStruct(ArrayParams{})
props := s["properties"].(map[string]any)
itemsSchema := props["items"].(map[string]any)
if itemsSchema["type"] != "array" {
t.Errorf("expected array type=array, got %v", itemsSchema["type"])
}
items := itemsSchema["items"].(map[string]any)
if items["type"] != "string" {
t.Errorf("expected items type=string, got %v", items["type"])
}
}
func TestFromStruct_Embedded(t *testing.T) {
s := FromStruct(EmbeddedParams{})
props := s["properties"].(map[string]any)
// Should have both ID from embedded and Name
if len(props) != 2 {
t.Errorf("expected 2 properties (flattened), got %d", len(props))
}
if _, ok := props["id"]; !ok {
t.Error("expected 'id' property from embedded struct")
}
if _, ok := props["name"]; !ok {
t.Error("expected 'name' property")
}
}
func TestFromStruct_ValidJSON(t *testing.T) {
s := FromStruct(SimpleParams{})
data, err := json.Marshal(s)
if err != nil {
t.Fatalf("schema should be valid JSON: %v", err)
}
var parsed map[string]any
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("schema should round-trip through JSON: %v", err)
}
}

207
v2/llm.go Normal file

@@ -0,0 +1,207 @@
package llm
import (
"context"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// Client represents an LLM provider. Create with OpenAI(), Anthropic(), Google().
type Client struct {
p provider.Provider
middleware []Middleware
}
// NewClient creates a Client backed by the given provider.
// Use this to integrate custom provider implementations or for testing.
func NewClient(p provider.Provider) *Client {
return &Client{p: p}
}
// Model returns a Model for the specified model version.
func (c *Client) Model(modelVersion string) *Model {
return &Model{
provider: c.p,
model: modelVersion,
middleware: c.middleware,
}
}
// WithMiddleware returns a new Client with additional middleware applied to all models.
func (c *Client) WithMiddleware(mw ...Middleware) *Client {
c2 := &Client{
p: c.p,
middleware: append(append([]Middleware{}, c.middleware...), mw...),
}
return c2
}
// Model represents a specific model from a provider, ready for completions.
type Model struct {
provider provider.Provider
model string
middleware []Middleware
}
// Complete sends a non-streaming completion request.
func (m *Model) Complete(ctx context.Context, messages []Message, opts ...RequestOption) (Response, error) {
cfg := &requestConfig{}
for _, opt := range opts {
opt(cfg)
}
chain := m.buildChain()
return chain(ctx, m.model, messages, cfg)
}
// Stream sends a streaming completion request, returning a StreamReader.
func (m *Model) Stream(ctx context.Context, messages []Message, opts ...RequestOption) (*StreamReader, error) {
cfg := &requestConfig{}
for _, opt := range opts {
opt(cfg)
}
req := buildProviderRequest(m.model, messages, cfg)
return newStreamReader(ctx, m.provider, req)
}
// WithMiddleware returns a new Model with additional middleware applied.
func (m *Model) WithMiddleware(mw ...Middleware) *Model {
return &Model{
provider: m.provider,
model: m.model,
middleware: append(append([]Middleware{}, m.middleware...), mw...),
}
}
func (m *Model) buildChain() CompletionFunc {
// Base handler that calls the provider
base := func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
req := buildProviderRequest(model, messages, cfg)
resp, err := m.provider.Complete(ctx, req)
if err != nil {
return Response{}, err
}
return convertProviderResponse(resp), nil
}
// Apply middleware in reverse order (first middleware wraps outermost)
chain := base
for i := len(m.middleware) - 1; i >= 0; i-- {
chain = m.middleware[i](chain)
}
return chain
}
func buildProviderRequest(model string, messages []Message, cfg *requestConfig) provider.Request {
req := provider.Request{
Model: model,
Messages: convertMessages(messages),
}
if cfg.temperature != nil {
req.Temperature = cfg.temperature
}
if cfg.maxTokens != nil {
req.MaxTokens = cfg.maxTokens
}
if cfg.topP != nil {
req.TopP = cfg.topP
}
if len(cfg.stop) > 0 {
req.Stop = cfg.stop
}
if cfg.tools != nil {
for _, tool := range cfg.tools.AllTools() {
req.Tools = append(req.Tools, provider.ToolDef{
Name: tool.Name,
Description: tool.Description,
Schema: tool.Schema,
})
}
}
return req
}
func convertMessages(msgs []Message) []provider.Message {
out := make([]provider.Message, len(msgs))
for i, m := range msgs {
pm := provider.Message{
Role: string(m.Role),
Content: m.Content.Text,
ToolCallID: m.ToolCallID,
}
for _, img := range m.Content.Images {
pm.Images = append(pm.Images, provider.Image{
URL: img.URL,
Base64: img.Base64,
ContentType: img.ContentType,
})
}
for _, aud := range m.Content.Audio {
pm.Audio = append(pm.Audio, provider.Audio{
URL: aud.URL,
Base64: aud.Base64,
ContentType: aud.ContentType,
})
}
for _, tc := range m.ToolCalls {
pm.ToolCalls = append(pm.ToolCalls, provider.ToolCall{
ID: tc.ID,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
out[i] = pm
}
return out
}
func convertProviderResponse(resp provider.Response) Response {
r := Response{
Text: resp.Text,
}
for _, tc := range resp.ToolCalls {
r.ToolCalls = append(r.ToolCalls, ToolCall{
ID: tc.ID,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
if resp.Usage != nil {
r.Usage = &Usage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
// Build the assistant message for conversation history
r.message = Message{
Role: RoleAssistant,
Content: Content{Text: resp.Text},
ToolCalls: r.ToolCalls,
}
return r
}
// --- Provider constructors ---
// These delegate to the provider-specific packages; see the Client doc
// comment for the available constructors (OpenAI(), Anthropic(), Google()).
// ClientOption configures a client.
type ClientOption func(*clientConfig)
type clientConfig struct {
baseURL string
}
// WithBaseURL overrides the API base URL.
func WithBaseURL(url string) ClientOption {
return func(c *clientConfig) { c.baseURL = url }
}

264
v2/mcp.go Normal file

@@ -0,0 +1,264 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"sync"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MCPTransport specifies how to connect to an MCP server.
type MCPTransport string
const (
MCPStdio MCPTransport = "stdio"
MCPSSE MCPTransport = "sse"
MCPHTTP MCPTransport = "http"
)
// MCPServer represents a connection to an MCP server.
type MCPServer struct {
name string
transport MCPTransport
// stdio fields
command string
args []string
env []string
// network fields
url string
// internal
client *mcp.Client
session *mcp.ClientSession
tools map[string]*mcp.Tool
mu sync.RWMutex
}
// MCPOption configures an MCP server.
type MCPOption func(*MCPServer)
// WithMCPEnv adds environment variables for the subprocess.
func WithMCPEnv(env ...string) MCPOption {
return func(s *MCPServer) { s.env = env }
}
// WithMCPName sets a friendly name for logging.
func WithMCPName(name string) MCPOption {
return func(s *MCPServer) { s.name = name }
}
// MCPStdioServer creates and connects to an MCP server via stdio transport.
//
// Example:
//
// server, err := llm.MCPStdioServer(ctx, "npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp")
func MCPStdioServer(ctx context.Context, command string, args ...string) (*MCPServer, error) {
s := &MCPServer{
name: command,
transport: MCPStdio,
command: command,
args: args,
}
if err := s.connect(ctx); err != nil {
return nil, err
}
return s, nil
}
// MCPHTTPServer creates and connects to an MCP server via streamable HTTP transport.
//
// Example:
//
// server, err := llm.MCPHTTPServer(ctx, "https://mcp.example.com")
func MCPHTTPServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) {
s := &MCPServer{
name: url,
transport: MCPHTTP,
url: url,
}
for _, opt := range opts {
opt(s)
}
if err := s.connect(ctx); err != nil {
return nil, err
}
return s, nil
}
// MCPSSEServer creates and connects to an MCP server via SSE transport.
func MCPSSEServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) {
s := &MCPServer{
name: url,
transport: MCPSSE,
url: url,
}
for _, opt := range opts {
opt(s)
}
if err := s.connect(ctx); err != nil {
return nil, err
}
return s, nil
}
func (s *MCPServer) connect(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.session != nil {
return nil
}
s.client = mcp.NewClient(&mcp.Implementation{
Name: "go-llm-v2",
Version: "2.0.0",
}, nil)
var transport mcp.Transport
switch s.transport {
case MCPSSE:
transport = &mcp.SSEClientTransport{
Endpoint: s.url,
}
case MCPHTTP:
transport = &mcp.StreamableClientTransport{
Endpoint: s.url,
}
default: // stdio
cmd := exec.Command(s.command, s.args...)
cmd.Env = append(os.Environ(), s.env...)
transport = &mcp.CommandTransport{
Command: cmd,
}
}
session, err := s.client.Connect(ctx, transport, nil)
if err != nil {
return fmt.Errorf("failed to connect to MCP server %s: %w", s.name, err)
}
s.session = session
// Load tools
s.tools = make(map[string]*mcp.Tool)
for tool, err := range session.Tools(ctx, nil) {
if err != nil {
s.session.Close()
s.session = nil
return fmt.Errorf("failed to list tools from %s: %w", s.name, err)
}
s.tools[tool.Name] = tool
}
return nil
}
// Close closes the connection to the MCP server.
func (s *MCPServer) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.session == nil {
return nil
}
err := s.session.Close()
s.session = nil
s.tools = nil
return err
}
// IsConnected returns true if the server is connected.
func (s *MCPServer) IsConnected() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.session != nil
}
// ListTools returns Tool definitions for all tools this server provides.
func (s *MCPServer) ListTools() []Tool {
s.mu.RLock()
defer s.mu.RUnlock()
var tools []Tool
for _, t := range s.tools {
tools = append(tools, s.toTool(t))
}
return tools
}
// CallTool invokes a tool on the server.
func (s *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) {
s.mu.RLock()
session := s.session
s.mu.RUnlock()
if session == nil {
return "", fmt.Errorf("%w: %s", ErrNotConnected, s.name)
}
result, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: name,
Arguments: arguments,
})
if err != nil {
return "", err
}
if len(result.Content) == 0 {
return "", nil
}
return contentToString(result.Content), nil
}
func (s *MCPServer) toTool(t *mcp.Tool) Tool {
var inputSchema map[string]any
if t.InputSchema != nil {
data, err := json.Marshal(t.InputSchema)
if err == nil {
_ = json.Unmarshal(data, &inputSchema)
}
}
if inputSchema == nil {
inputSchema = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
return Tool{
Name: t.Name,
Description: t.Description,
Schema: inputSchema,
isMCP: true,
mcpServer: s,
}
}
func contentToString(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch tc := c.(type) {
case *mcp.TextContent:
parts = append(parts, tc.Text)
default:
if data, err := json.Marshal(c); err == nil {
parts = append(parts, string(data))
}
}
}
if len(parts) == 1 {
return parts[0]
}
data, _ := json.Marshal(parts)
return string(data)
}

87
v2/message.go Normal file

@@ -0,0 +1,87 @@
package llm
// Role represents who authored a message.
type Role string
const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
RoleTool Role = "tool"
)
// Image represents an image attachment.
type Image struct {
// Provide exactly one of URL or Base64.
URL string // HTTP(S) URL
Base64 string // Raw base64-encoded data
ContentType string // MIME type (e.g., "image/png"), required for Base64
}
// Audio represents an audio attachment.
type Audio struct {
// Provide exactly one of URL or Base64.
URL string // HTTP(S) URL to audio file
Base64 string // Raw base64-encoded audio data
ContentType string // MIME type (e.g., "audio/wav", "audio/mp3")
}
// Content represents message content with optional text, images, and audio.
type Content struct {
Text string
Images []Image
Audio []Audio
}
// ToolCall represents a tool invocation requested by the assistant.
type ToolCall struct {
ID string
Name string
Arguments string // raw JSON
}
// Message represents a single message in a conversation.
type Message struct {
Role Role
Content Content
// ToolCallID is set when Role == RoleTool, identifying which tool call this responds to.
ToolCallID string
// ToolCalls is set when the assistant requests tool invocations.
ToolCalls []ToolCall
}
// UserMessage creates a user message with text content.
func UserMessage(text string) Message {
return Message{Role: RoleUser, Content: Content{Text: text}}
}
// UserMessageWithImages creates a user message with text and images.
func UserMessageWithImages(text string, images ...Image) Message {
return Message{Role: RoleUser, Content: Content{Text: text, Images: images}}
}
// UserMessageWithAudio creates a user message with text and audio attachments.
func UserMessageWithAudio(text string, audio ...Audio) Message {
return Message{Role: RoleUser, Content: Content{Text: text, Audio: audio}}
}
// SystemMessage creates a system prompt message.
func SystemMessage(text string) Message {
return Message{Role: RoleSystem, Content: Content{Text: text}}
}
// AssistantMessage creates an assistant message with text content.
func AssistantMessage(text string) Message {
return Message{Role: RoleAssistant, Content: Content{Text: text}}
}
// ToolResultMessage creates a tool result message.
func ToolResultMessage(toolCallID string, result string) Message {
return Message{
Role: RoleTool,
Content: Content{Text: result},
ToolCallID: toolCallID,
}
}
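Both Image and Audio document an exactly-one-of invariant (URL or Base64) that the structs themselves do not enforce. A caller-side check might look like the following sketch; `validateImage` is a hypothetical helper, not part of this package:

```go
package main

import (
	"errors"
	"fmt"
)

// Image mirrors the v2 Image struct for illustration.
type Image struct {
	URL         string
	Base64      string
	ContentType string
}

// validateImage enforces the documented "exactly one of URL or Base64" rule,
// plus the ContentType requirement for base64 payloads.
func validateImage(img Image) error {
	switch {
	case img.URL != "" && img.Base64 != "":
		return errors.New("image: provide only one of URL or Base64")
	case img.URL == "" && img.Base64 == "":
		return errors.New("image: provide URL or Base64")
	case img.Base64 != "" && img.ContentType == "":
		return errors.New("image: ContentType is required with Base64")
	}
	return nil
}

func main() {
	fmt.Println(validateImage(Image{URL: "https://example.com/a.png"})) // <nil>
	fmt.Println(validateImage(Image{Base64: "abc"}))
}
```

The same shape of check would apply to Audio, where providers such as OpenAI additionally need the ContentType to pick an input format.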

v2/message_test.go Normal file

@@ -0,0 +1,212 @@
package llm
import (
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestUserMessage(t *testing.T) {
msg := UserMessage("hello")
if msg.Role != RoleUser {
t.Errorf("expected role=user, got %v", msg.Role)
}
if msg.Content.Text != "hello" {
t.Errorf("expected text='hello', got %q", msg.Content.Text)
}
if len(msg.Content.Images) != 0 {
t.Errorf("expected no images, got %d", len(msg.Content.Images))
}
}
func TestUserMessageWithImages(t *testing.T) {
img1 := Image{URL: "https://example.com/1.png"}
img2 := Image{Base64: "abc123", ContentType: "image/png"}
msg := UserMessageWithImages("describe", img1, img2)
if msg.Role != RoleUser {
t.Errorf("expected role=user, got %v", msg.Role)
}
if msg.Content.Text != "describe" {
t.Errorf("expected text='describe', got %q", msg.Content.Text)
}
if len(msg.Content.Images) != 2 {
t.Fatalf("expected 2 images, got %d", len(msg.Content.Images))
}
if msg.Content.Images[0].URL != "https://example.com/1.png" {
t.Errorf("expected image[0] URL, got %q", msg.Content.Images[0].URL)
}
if msg.Content.Images[1].Base64 != "abc123" {
t.Errorf("expected image[1] base64='abc123', got %q", msg.Content.Images[1].Base64)
}
if msg.Content.Images[1].ContentType != "image/png" {
t.Errorf("expected image[1] contentType='image/png', got %q", msg.Content.Images[1].ContentType)
}
}
func TestSystemMessage(t *testing.T) {
msg := SystemMessage("Be helpful")
if msg.Role != RoleSystem {
t.Errorf("expected role=system, got %v", msg.Role)
}
if msg.Content.Text != "Be helpful" {
t.Errorf("expected text='Be helpful', got %q", msg.Content.Text)
}
}
func TestAssistantMessage(t *testing.T) {
msg := AssistantMessage("Sure thing")
if msg.Role != RoleAssistant {
t.Errorf("expected role=assistant, got %v", msg.Role)
}
if msg.Content.Text != "Sure thing" {
t.Errorf("expected text='Sure thing', got %q", msg.Content.Text)
}
}
func TestToolResultMessage(t *testing.T) {
msg := ToolResultMessage("tc-123", "result data")
if msg.Role != RoleTool {
t.Errorf("expected role=tool, got %v", msg.Role)
}
if msg.ToolCallID != "tc-123" {
t.Errorf("expected toolCallID='tc-123', got %q", msg.ToolCallID)
}
if msg.Content.Text != "result data" {
t.Errorf("expected text='result data', got %q", msg.Content.Text)
}
}
func TestConvertMessages(t *testing.T) {
msgs := []Message{
SystemMessage("system prompt"),
UserMessageWithImages("look at this", Image{URL: "https://example.com/img.png"}),
{
Role: RoleAssistant,
Content: Content{Text: "I'll use a tool"},
ToolCalls: []ToolCall{
{ID: "tc1", Name: "search", Arguments: `{"q":"test"}`},
},
},
ToolResultMessage("tc1", "found it"),
}
converted := convertMessages(msgs)
if len(converted) != 4 {
t.Fatalf("expected 4 converted messages, got %d", len(converted))
}
// System message
if converted[0].Role != "system" {
t.Errorf("msg[0]: expected role='system', got %q", converted[0].Role)
}
if converted[0].Content != "system prompt" {
t.Errorf("msg[0]: expected content='system prompt', got %q", converted[0].Content)
}
// User message with images
if converted[1].Role != "user" {
t.Errorf("msg[1]: expected role='user', got %q", converted[1].Role)
}
if len(converted[1].Images) != 1 {
t.Fatalf("msg[1]: expected 1 image, got %d", len(converted[1].Images))
}
if converted[1].Images[0].URL != "https://example.com/img.png" {
t.Errorf("msg[1]: expected image URL, got %q", converted[1].Images[0].URL)
}
// Assistant message with tool calls
if converted[2].Role != "assistant" {
t.Errorf("msg[2]: expected role='assistant', got %q", converted[2].Role)
}
if len(converted[2].ToolCalls) != 1 {
t.Fatalf("msg[2]: expected 1 tool call, got %d", len(converted[2].ToolCalls))
}
if converted[2].ToolCalls[0].ID != "tc1" {
t.Errorf("msg[2]: expected tool call ID='tc1', got %q", converted[2].ToolCalls[0].ID)
}
if converted[2].ToolCalls[0].Name != "search" {
t.Errorf("msg[2]: expected tool call name='search', got %q", converted[2].ToolCalls[0].Name)
}
if converted[2].ToolCalls[0].Arguments != `{"q":"test"}` {
t.Errorf("msg[2]: expected tool call arguments, got %q", converted[2].ToolCalls[0].Arguments)
}
// Tool result message
if converted[3].Role != "tool" {
t.Errorf("msg[3]: expected role='tool', got %q", converted[3].Role)
}
if converted[3].ToolCallID != "tc1" {
t.Errorf("msg[3]: expected toolCallID='tc1', got %q", converted[3].ToolCallID)
}
if converted[3].Content != "found it" {
t.Errorf("msg[3]: expected content='found it', got %q", converted[3].Content)
}
}
func TestConvertProviderResponse(t *testing.T) {
t.Run("text only", func(t *testing.T) {
resp := convertProviderResponse(provider.Response{
Text: "hello",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
if resp.Text != "hello" {
t.Errorf("expected text='hello', got %q", resp.Text)
}
if resp.HasToolCalls() {
t.Error("expected no tool calls")
}
if resp.Usage == nil {
t.Fatal("expected usage")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
}
msg := resp.Message()
if msg.Role != RoleAssistant {
t.Errorf("expected role=assistant, got %v", msg.Role)
}
if msg.Content.Text != "hello" {
t.Errorf("expected message text='hello', got %q", msg.Content.Text)
}
})
t.Run("with tool calls", func(t *testing.T) {
resp := convertProviderResponse(provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "search", Arguments: `{"q":"go"}`},
{ID: "tc2", Name: "calc", Arguments: `{"a":1}`},
},
})
if !resp.HasToolCalls() {
t.Fatal("expected tool calls")
}
if len(resp.ToolCalls) != 2 {
t.Fatalf("expected 2 tool calls, got %d", len(resp.ToolCalls))
}
if resp.ToolCalls[0].ID != "tc1" || resp.ToolCalls[0].Name != "search" {
t.Errorf("unexpected tool call[0]: %+v", resp.ToolCalls[0])
}
if resp.ToolCalls[1].ID != "tc2" || resp.ToolCalls[1].Name != "calc" {
t.Errorf("unexpected tool call[1]: %+v", resp.ToolCalls[1])
}
msg := resp.Message()
if len(msg.ToolCalls) != 2 {
t.Errorf("expected 2 tool calls in message, got %d", len(msg.ToolCalls))
}
})
t.Run("nil usage", func(t *testing.T) {
resp := convertProviderResponse(provider.Response{Text: "ok"})
if resp.Usage != nil {
t.Errorf("expected nil usage, got %+v", resp.Usage)
}
})
}

v2/middleware.go Normal file

@@ -0,0 +1,117 @@
package llm
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// CompletionFunc is the signature for the completion call chain.
type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error)
// Middleware wraps a completion call. It receives the next handler in the chain
// and returns a new handler that can inspect/modify the request and response.
type Middleware func(next CompletionFunc) CompletionFunc
// WithLogging returns middleware that logs requests and responses via slog.
func WithLogging(logger *slog.Logger) Middleware {
return func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
logger.Info("llm request",
"model", model,
"message_count", len(messages),
)
start := time.Now()
resp, err := next(ctx, model, messages, cfg)
elapsed := time.Since(start)
if err != nil {
logger.Error("llm error", "model", model, "elapsed", elapsed, "error", err)
} else {
logger.Info("llm response",
"model", model,
"elapsed", elapsed,
"text_len", len(resp.Text),
"tool_calls", len(resp.ToolCalls),
)
}
return resp, err
}
}
}
// WithRetry returns middleware that retries failed requests with configurable backoff.
func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware {
return func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
select {
case <-ctx.Done():
return Response{}, ctx.Err()
case <-time.After(backoff(attempt)):
}
}
resp, err := next(ctx, model, messages, cfg)
if err == nil {
return resp, nil
}
lastErr = err
}
return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr)
}
}
}
// WithTimeout returns middleware that enforces a per-request timeout.
func WithTimeout(d time.Duration) Middleware {
return func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
ctx, cancel := context.WithTimeout(ctx, d)
defer cancel()
return next(ctx, model, messages, cfg)
}
}
}
// UsageTracker accumulates token usage statistics across calls.
type UsageTracker struct {
mu sync.Mutex
TotalInput int64
TotalOutput int64
TotalRequests int64
}
// Add records usage from a single request.
func (ut *UsageTracker) Add(u *Usage) {
if u == nil {
return
}
ut.mu.Lock()
defer ut.mu.Unlock()
ut.TotalInput += int64(u.InputTokens)
ut.TotalOutput += int64(u.OutputTokens)
ut.TotalRequests++
}
// Summary returns the accumulated totals.
func (ut *UsageTracker) Summary() (input, output, requests int64) {
ut.mu.Lock()
defer ut.mu.Unlock()
return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
}
// WithUsageTracking returns middleware that accumulates token usage across calls.
func WithUsageTracking(tracker *UsageTracker) Middleware {
return func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
resp, err := next(ctx, model, messages, cfg)
if err == nil {
tracker.Add(resp.Usage)
}
return resp, err
}
}
}
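Middleware composes outside-in: the first middleware applied wraps the entire chain, so its "before" work runs first and its "after" work runs last (the ordering `TestMiddleware_Chaining` asserts). The same pattern in a self-contained sketch, with a simplified handler signature standing in for CompletionFunc:

```go
package main

import "fmt"

// Handler is a simplified stand-in for CompletionFunc.
type Handler func(prompt string) string

// Middleware wraps a Handler, mirroring the v2 Middleware shape.
type Middleware func(next Handler) Handler

// chain applies middleware so that mws[0] is outermost.
func chain(h Handler, mws ...Middleware) Handler {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

func main() {
	logging := func(next Handler) Handler {
		return func(p string) string {
			fmt.Println("before:", p)
			out := next(p)
			fmt.Println("after:", out)
			return out
		}
	}
	exclaim := func(next Handler) Handler {
		return func(p string) string { return next(p) + "!" }
	}
	h := chain(func(p string) string { return "echo " + p }, logging, exclaim)
	fmt.Println(h("hi")) // before: hi / after: echo hi! / echo hi!
}
```

Because wrapping happens from the last middleware to the first, WithRetry placed before WithTimeout retries the whole timed call, while the reverse order applies one timeout across all retries.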

v2/middleware_test.go Normal file

@@ -0,0 +1,282 @@
package llm
import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestWithRetry_Success(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp).WithMiddleware(
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "ok" {
t.Errorf("expected 'ok', got %q", resp.Text)
}
if len(mp.Requests) != 1 {
t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests))
}
}
func TestWithRetry_EventualSuccess(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n <= 2 {
return provider.Response{}, errors.New("transient error")
}
return provider.Response{Text: "success"}, nil
})
model := newMockModel(mp).WithMiddleware(
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "success" {
t.Errorf("expected 'success', got %q", resp.Text)
}
if atomic.LoadInt32(&callCount) != 3 {
t.Errorf("expected 3 calls, got %d", callCount)
}
}
func TestWithRetry_AllFail(t *testing.T) {
providerErr := errors.New("persistent error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
model := newMockModel(mp).WithMiddleware(
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, providerErr) {
t.Errorf("expected wrapped persistent error, got %v", err)
}
if len(mp.Requests) != 3 {
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
}
}
func TestWithRetry_ContextCancelled(t *testing.T) {
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, errors.New("fail")
})
model := newMockModel(mp).WithMiddleware(
WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
)
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
_, err := model.Complete(ctx, []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestWithTimeout(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "fast"})
model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "fast" {
t.Errorf("expected 'fast', got %q", resp.Text)
}
}
func TestWithTimeout_Exceeded(t *testing.T) {
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
select {
case <-ctx.Done():
return provider.Response{}, ctx.Err()
case <-time.After(5 * time.Second):
return provider.Response{Text: "slow"}, nil
}
})
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected DeadlineExceeded, got %v", err)
}
}
func TestWithUsageTracking(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "ok",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
// Make two requests
for i := 0; i < 2; i++ {
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error on call %d: %v", i, err)
}
}
input, output, requests := tracker.Summary()
if input != 20 {
t.Errorf("expected total input 20, got %d", input)
}
if output != 10 {
t.Errorf("expected total output 10, got %d", output)
}
if requests != 2 {
t.Errorf("expected 2 requests, got %d", requests)
}
}
func TestWithUsageTracking_NilUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
input, output, requests := tracker.Summary()
if input != 0 || output != 0 {
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
}
// Add(nil) returns early without incrementing TotalRequests
if requests != 0 {
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
}
}
func TestUsageTracker_Concurrent(t *testing.T) {
tracker := &UsageTracker{}
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
tracker.Add(&Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
})
}()
}
wg.Wait()
input, output, requests := tracker.Summary()
if input != 1000 {
t.Errorf("expected total input 1000, got %d", input)
}
if output != 500 {
t.Errorf("expected total output 500, got %d", output)
}
if requests != 100 {
t.Errorf("expected 100 requests, got %d", requests)
}
}
func TestMiddleware_Chaining(t *testing.T) {
var order []string
mw1 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw1-before")
resp, err := next(ctx, model, messages, cfg)
order = append(order, "mw1-after")
return resp, err
}
}
mw2 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw2-before")
resp, err := next(ctx, model, messages, cfg)
order = append(order, "mw2-after")
return resp, err
}
}
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp).WithMiddleware(mw1, mw2)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
if len(order) != len(expected) {
t.Fatalf("expected %d middleware calls, got %d: %v", len(expected), len(order), order)
}
for i, v := range expected {
if order[i] != v {
t.Errorf("order[%d]: expected %q, got %q", i, v, order[i])
}
}
}
func TestWithLogging(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "logged"})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "logged" {
t.Errorf("expected 'logged', got %q", resp.Text)
}
}
func TestWithLogging_Error(t *testing.T) {
providerErr := errors.New("log this error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if !errors.Is(err, providerErr) {
t.Errorf("expected provider error, got %v", err)
}
}

v2/mock_provider_test.go Normal file

@@ -0,0 +1,87 @@
package llm
import (
"context"
"sync"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// mockProvider is a configurable mock implementation of provider.Provider for testing.
type mockProvider struct {
CompleteFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
StreamFunc func(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error
// mu guards Requests
mu sync.Mutex
Requests []provider.Request
}
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
m.mu.Lock()
m.Requests = append(m.Requests, req)
m.mu.Unlock()
return m.CompleteFunc(ctx, req)
}
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
m.mu.Lock()
m.Requests = append(m.Requests, req)
m.mu.Unlock()
if m.StreamFunc != nil {
return m.StreamFunc(ctx, req, events)
}
close(events)
return nil
}
// lastRequest returns the most recent request recorded by the mock.
func (m *mockProvider) lastRequest() provider.Request {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.Requests) == 0 {
return provider.Request{}
}
return m.Requests[len(m.Requests)-1]
}
// newMockProvider creates a mock that always returns the given response.
func newMockProvider(resp provider.Response) *mockProvider {
return &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return resp, nil
},
}
}
// newMockProviderFunc creates a mock with a custom Complete function.
func newMockProviderFunc(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *mockProvider {
return &mockProvider{CompleteFunc: fn}
}
// newMockStreamProvider creates a mock that streams the given events.
func newMockStreamProvider(events []provider.StreamEvent) *mockProvider {
return &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
for _, ev := range events {
select {
case ch <- ev:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
},
}
}
// newMockModel creates a *Model backed by the given mock provider.
func newMockModel(p *mockProvider) *Model {
return &Model{
provider: p,
model: "mock-model",
}
}

v2/model_test.go Normal file

@@ -0,0 +1,215 @@
package llm
import (
"context"
"errors"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestModel_Complete(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "Hello!",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "Hello!" {
t.Errorf("expected text 'Hello!', got %q", resp.Text)
}
if resp.Usage == nil {
t.Fatal("expected usage, got nil")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
}
if resp.Usage.OutputTokens != 5 {
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
}
if resp.Usage.TotalTokens != 15 {
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
}
}
func TestModel_Complete_WithOptions(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
temp := 0.7
maxTok := 100
topP := 0.9
_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
WithTemperature(temp),
WithMaxTokens(maxTok),
WithTopP(topP),
WithStop("STOP", "END"),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if req.Temperature == nil || *req.Temperature != temp {
t.Errorf("expected temperature %v, got %v", temp, req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != maxTok {
t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens)
}
if req.TopP == nil || *req.TopP != topP {
t.Errorf("expected topP %v, got %v", topP, req.TopP)
}
if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
t.Errorf("expected stop [STOP END], got %v", req.Stop)
}
}
func TestModel_Complete_Error(t *testing.T) {
wantErr := errors.New("provider error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, wantErr
})
model := newMockModel(mp)
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, wantErr) {
t.Errorf("expected error %v, got %v", wantErr, err)
}
}
func TestModel_Complete_WithTools(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "done"})
model := newMockModel(mp)
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello", nil
})
tb := NewToolBox(tool)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if len(req.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
}
if req.Tools[0].Name != "greet" {
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
}
if req.Tools[0].Description != "Says hello" {
t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description)
}
}
func TestClient_Model(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "hi"})
client := NewClient(mp)
model := client.Model("test-model")
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "hi" {
t.Errorf("expected 'hi', got %q", resp.Text)
}
req := mp.lastRequest()
if req.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", req.Model)
}
}
func TestClient_WithMiddleware(t *testing.T) {
var called bool
mw := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
called = true
return next(ctx, model, messages, cfg)
}
}
mp := newMockProvider(provider.Response{Text: "ok"})
client := NewClient(mp).WithMiddleware(mw)
model := client.Model("test-model")
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Error("middleware was not called")
}
}
func TestModel_WithMiddleware(t *testing.T) {
var order []string
mw1 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw1")
return next(ctx, model, messages, cfg)
}
}
mw2 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw2")
return next(ctx, model, messages, cfg)
}
}
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" {
t.Errorf("expected middleware order [mw1 mw2], got %v", order)
}
}
func TestModel_Complete_NoUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Usage != nil {
t.Errorf("expected nil usage, got %+v", resp.Usage)
}
}
func TestModel_Complete_ResponseMessage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "response text"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
msg := resp.Message()
if msg.Role != RoleAssistant {
t.Errorf("expected role assistant, got %v", msg.Role)
}
if msg.Content.Text != "response text" {
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
}
}

v2/openai/openai.go Normal file

@@ -0,0 +1,395 @@
// Package openai implements the go-llm v2 provider interface for OpenAI.
package openai
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"path"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/shared"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// Provider implements the provider.Provider interface for OpenAI.
type Provider struct {
apiKey string
baseURL string
}
// New creates a new OpenAI provider.
func New(apiKey string, baseURL string) *Provider {
return &Provider{apiKey: apiKey, baseURL: baseURL}
}
// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
var opts []option.RequestOption
opts = append(opts, option.WithAPIKey(p.apiKey))
if p.baseURL != "" {
opts = append(opts, option.WithBaseURL(p.baseURL))
}
cl := openai.NewClient(opts...)
oaiReq := p.buildRequest(req)
resp, err := cl.Chat.Completions.New(ctx, oaiReq)
if err != nil {
return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
}
return p.convertResponse(resp), nil
}
// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
var opts []option.RequestOption
opts = append(opts, option.WithAPIKey(p.apiKey))
if p.baseURL != "" {
opts = append(opts, option.WithBaseURL(p.baseURL))
}
cl := openai.NewClient(opts...)
oaiReq := p.buildRequest(req)
stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
var fullText strings.Builder
var toolCalls []provider.ToolCall
toolCallArgs := map[int]*strings.Builder{}
for stream.Next() {
chunk := stream.Current()
for _, choice := range chunk.Choices {
// Text delta
if choice.Delta.Content != "" {
fullText.WriteString(choice.Delta.Content)
events <- provider.StreamEvent{
Type: provider.StreamEventText,
Text: choice.Delta.Content,
}
}
// Tool call deltas
for _, tc := range choice.Delta.ToolCalls {
idx := int(tc.Index)
if tc.ID != "" {
// New tool call starting
for len(toolCalls) <= idx {
toolCalls = append(toolCalls, provider.ToolCall{})
}
toolCalls[idx].ID = tc.ID
toolCalls[idx].Name = tc.Function.Name
toolCallArgs[idx] = &strings.Builder{}
events <- provider.StreamEvent{
Type: provider.StreamEventToolStart,
ToolCall: &provider.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
},
ToolIndex: idx,
}
}
if tc.Function.Arguments != "" {
if b, ok := toolCallArgs[idx]; ok {
b.WriteString(tc.Function.Arguments)
}
events <- provider.StreamEvent{
Type: provider.StreamEventToolDelta,
ToolIndex: idx,
ToolCall: &provider.ToolCall{
Arguments: tc.Function.Arguments,
},
}
}
}
}
}
if err := stream.Err(); err != nil {
return fmt.Errorf("openai stream error: %w", err)
}
// Finalize tool calls
for idx := range toolCalls {
if b, ok := toolCallArgs[idx]; ok {
toolCalls[idx].Arguments = b.String()
}
events <- provider.StreamEvent{
Type: provider.StreamEventToolEnd,
ToolIndex: idx,
ToolCall: &toolCalls[idx],
}
}
// Send done event
events <- provider.StreamEvent{
Type: provider.StreamEventDone,
Response: &provider.Response{
Text: fullText.String(),
ToolCalls: toolCalls,
},
}
return nil
}
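The Stream loop above reassembles tool calls from fragments keyed by index: a delta carrying an ID starts a new call at that index, and later argument fragments for the same index are appended to its builder. The accumulation logic in isolation (`delta` is a hypothetical simplified chunk type, not the SDK's):

```go
package main

import (
	"fmt"
	"strings"
)

// delta is a simplified stand-in for an OpenAI streaming tool-call chunk.
type delta struct {
	Index int
	ID    string // non-empty only on the first chunk of a call
	Name  string
	Args  string // argument JSON fragment
}

type toolCall struct {
	ID, Name, Arguments string
}

// accumulate mirrors the index-keyed builder pattern from Stream.
func accumulate(deltas []delta) []toolCall {
	var calls []toolCall
	args := map[int]*strings.Builder{}
	for _, d := range deltas {
		if d.ID != "" {
			// New tool call starting at this index.
			for len(calls) <= d.Index {
				calls = append(calls, toolCall{})
			}
			calls[d.Index].ID = d.ID
			calls[d.Index].Name = d.Name
			args[d.Index] = &strings.Builder{}
		}
		if d.Args != "" {
			if b, ok := args[d.Index]; ok {
				b.WriteString(d.Args)
			}
		}
	}
	for i := range calls {
		if b, ok := args[i]; ok {
			calls[i].Arguments = b.String()
		}
	}
	return calls
}

func main() {
	calls := accumulate([]delta{
		{Index: 0, ID: "tc1", Name: "search"},
		{Index: 0, Args: `{"q":`},
		{Index: 0, Args: `"go"}`},
	})
	fmt.Println(calls[0].Arguments) // {"q":"go"}
}
```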
func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
oaiReq := openai.ChatCompletionNewParams{
Model: req.Model,
}
for _, msg := range req.Messages {
oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
}
for _, tool := range req.Tools {
oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
Type: "function",
Function: shared.FunctionDefinitionParam{
Name: tool.Name,
Description: openai.String(tool.Description),
Parameters: openai.FunctionParameters(tool.Schema),
},
})
}
if req.Temperature != nil {
// o* and gpt-5* models don't support custom temperatures
if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") {
oaiReq.Temperature = openai.Float(*req.Temperature)
}
}
if req.MaxTokens != nil {
oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
}
if req.TopP != nil {
oaiReq.TopP = openai.Float(*req.TopP)
}
if len(req.Stop) > 0 {
// Only the first stop sequence is forwarded here; additional entries are dropped.
oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
}
return oaiReq
}
func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
var textContent param.Opt[string]
for _, img := range msg.Images {
var url string
if img.Base64 != "" {
url = "data:" + img.ContentType + ";base64," + img.Base64
} else if img.URL != "" {
url = img.URL
}
if url != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfImageURL: &openai.ChatCompletionContentPartImageParam{
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
URL: url,
},
},
},
)
}
}
for _, aud := range msg.Audio {
var b64Data string
var format string
if aud.Base64 != "" {
b64Data = aud.Base64
format = audioFormat(aud.ContentType)
} else if aud.URL != "" {
resp, err := http.Get(aud.URL)
if err != nil {
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
// Skip the attachment rather than failing the whole request if
// the fetch did not succeed.
if err != nil || resp.StatusCode != http.StatusOK {
continue
}
b64Data = base64.StdEncoding.EncodeToString(data)
ct := resp.Header.Get("Content-Type")
if ct == "" {
ct = aud.ContentType
}
if ct == "" {
ct = audioFormatFromURL(aud.URL)
}
format = audioFormat(ct)
}
if b64Data != "" && format != "" {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
Data: b64Data,
Format: format,
},
},
},
)
}
}
if msg.Content != "" {
if len(arrayOfContentParts) > 0 {
arrayOfContentParts = append(arrayOfContentParts,
openai.ChatCompletionContentPartUnionParam{
OfText: &openai.ChatCompletionContentPartTextParam{
Text: msg.Content,
},
},
)
} else {
textContent = openai.String(msg.Content)
}
}
// Determine if this model uses developer messages instead of system
useDeveloper := false
parts := strings.Split(model, "-")
if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
useDeveloper = true
}
switch msg.Role {
case "system":
if useDeveloper {
return openai.ChatCompletionMessageParamUnion{
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
OfString: textContent,
},
},
}
}
return openai.ChatCompletionMessageParamUnion{
OfSystem: &openai.ChatCompletionSystemMessageParam{
Content: openai.ChatCompletionSystemMessageParamContentUnion{
OfString: textContent,
},
},
}
case "user":
return openai.ChatCompletionMessageParamUnion{
OfUser: &openai.ChatCompletionUserMessageParam{
Content: openai.ChatCompletionUserMessageParamContentUnion{
OfString: textContent,
OfArrayOfContentParts: arrayOfContentParts,
},
},
}
case "assistant":
as := &openai.ChatCompletionAssistantMessageParam{}
if msg.Content != "" {
as.Content.OfString = openai.String(msg.Content)
}
for _, tc := range msg.ToolCalls {
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
ID: tc.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: tc.Name,
Arguments: tc.Arguments,
},
})
}
return openai.ChatCompletionMessageParamUnion{OfAssistant: as}
case "tool":
return openai.ChatCompletionMessageParamUnion{
OfTool: &openai.ChatCompletionToolMessageParam{
ToolCallID: msg.ToolCallID,
Content: openai.ChatCompletionToolMessageParamContentUnion{
OfString: openai.String(msg.Content),
},
},
}
}
// Fallback to user message
return openai.ChatCompletionMessageParamUnion{
OfUser: &openai.ChatCompletionUserMessageParam{
Content: openai.ChatCompletionUserMessageParamContentUnion{
OfString: textContent,
},
},
}
}
func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
var res provider.Response
if resp == nil || len(resp.Choices) == 0 {
return res
}
choice := resp.Choices[0]
res.Text = choice.Message.Content
for _, tc := range choice.Message.ToolCalls {
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Arguments: strings.TrimSpace(tc.Function.Arguments),
})
}
if resp.Usage.TotalTokens > 0 {
res.Usage = &provider.Usage{
InputTokens: int(resp.Usage.PromptTokens),
OutputTokens: int(resp.Usage.CompletionTokens),
TotalTokens: int(resp.Usage.TotalTokens),
}
}
return res
}
// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3").
func audioFormat(contentType string) string {
ct := strings.ToLower(contentType)
switch {
case strings.Contains(ct, "wav"):
return "wav"
case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"):
return "mp3"
default:
return "wav"
}
}
// audioFormatFromURL guesses the audio MIME type from a URL's file
// extension; the result is then mapped through audioFormat.
func audioFormatFromURL(u string) string {
ext := strings.ToLower(path.Ext(u))
switch ext {
case ".mp3":
return "audio/mp3"
case ".wav":
return "audio/wav"
default:
return "audio/wav"
}
}

v2/openai/transcriber.go Normal file

@@ -0,0 +1,230 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// Transcriber implements the provider.Transcriber interface using OpenAI's audio models.
type Transcriber struct {
key string
model string
baseURL string
}
var _ provider.Transcriber = (*Transcriber)(nil)
// NewTranscriber creates a transcriber backed by OpenAI's audio models.
// If model is empty, "whisper-1" is used by default.
func NewTranscriber(key string, model string) *Transcriber {
if strings.TrimSpace(model) == "" {
model = "whisper-1"
}
return &Transcriber{
key: key,
model: model,
}
}
// NewTranscriberWithBaseURL creates a transcriber with a custom API base URL.
func NewTranscriberWithBaseURL(key, model, baseURL string) *Transcriber {
t := NewTranscriber(key, model)
t.baseURL = baseURL
return t
}
// Transcribe performs speech-to-text transcription of WAV audio data.
func (t *Transcriber) Transcribe(ctx context.Context, wav []byte, opts provider.TranscriptionOptions) (provider.Transcription, error) {
if len(wav) == 0 {
return provider.Transcription{}, fmt.Errorf("wav data is empty")
}
format := opts.ResponseFormat
if format == "" {
if strings.HasPrefix(t.model, "gpt-4o") {
format = provider.TranscriptionResponseFormatJSON
} else {
format = provider.TranscriptionResponseFormatVerboseJSON
}
}
if format != provider.TranscriptionResponseFormatJSON && format != provider.TranscriptionResponseFormatVerboseJSON {
return provider.Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output")
}
if len(opts.TimestampGranularities) > 0 && format != provider.TranscriptionResponseFormatVerboseJSON {
return provider.Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json")
}
params := openai.AudioTranscriptionNewParams{
File: openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"),
Model: openai.AudioModel(t.model),
}
if opts.Language != "" {
params.Language = openai.String(opts.Language)
}
if opts.Prompt != "" {
params.Prompt = openai.String(opts.Prompt)
}
if opts.Temperature != nil {
params.Temperature = openai.Float(*opts.Temperature)
}
params.ResponseFormat = openai.AudioResponseFormat(format)
if opts.IncludeLogprobs {
params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs}
}
if len(opts.TimestampGranularities) > 0 {
for _, granularity := range opts.TimestampGranularities {
params.TimestampGranularities = append(params.TimestampGranularities, string(granularity))
}
}
clientOptions := []option.RequestOption{
option.WithAPIKey(t.key),
}
if t.baseURL != "" {
clientOptions = append(clientOptions, option.WithBaseURL(t.baseURL))
}
client := openai.NewClient(clientOptions...)
resp, err := client.Audio.Transcriptions.New(ctx, params)
if err != nil {
return provider.Transcription{}, fmt.Errorf("openai transcription failed: %w", err)
}
return transcriptionToResult(t.model, resp), nil
}
type verboseTranscription struct {
Text string `json:"text"`
Language string `json:"language"`
Duration float64 `json:"duration"`
Segments []verboseSegment `json:"segments"`
Words []verboseWord `json:"words"`
}
type verboseSegment struct {
ID int `json:"id"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
AvgLogprob *float64 `json:"avg_logprob"`
CompressionRatio *float64 `json:"compression_ratio"`
NoSpeechProb *float64 `json:"no_speech_prob"`
Words []verboseWord `json:"words"`
}
type verboseWord struct {
Word string `json:"word"`
Start float64 `json:"start"`
End float64 `json:"end"`
}
func transcriptionToResult(model string, resp *openai.Transcription) provider.Transcription {
result := provider.Transcription{
Provider: "openai",
Model: model,
}
if resp == nil {
return result
}
result.Text = resp.Text
result.RawJSON = resp.RawJSON()
for _, logprob := range resp.Logprobs {
result.Logprobs = append(result.Logprobs, provider.TranscriptionTokenLogprob{
Token: logprob.Token,
Bytes: logprob.Bytes,
Logprob: logprob.Logprob,
})
}
if usage := usageToTranscriptionUsage(resp.Usage); usage.Type != "" {
result.Usage = usage
}
if result.RawJSON == "" {
return result
}
var verbose verboseTranscription
if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil {
return result
}
if verbose.Text != "" {
result.Text = verbose.Text
}
result.Language = verbose.Language
result.DurationSeconds = verbose.Duration
for _, seg := range verbose.Segments {
segment := provider.TranscriptionSegment{
ID: seg.ID,
Start: seg.Start,
End: seg.End,
Text: seg.Text,
Tokens: append([]int(nil), seg.Tokens...),
AvgLogprob: seg.AvgLogprob,
CompressionRatio: seg.CompressionRatio,
NoSpeechProb: seg.NoSpeechProb,
}
for _, word := range seg.Words {
segment.Words = append(segment.Words, provider.TranscriptionWord{
Word: word.Word,
Start: word.Start,
End: word.End,
})
}
result.Segments = append(result.Segments, segment)
}
for _, word := range verbose.Words {
result.Words = append(result.Words, provider.TranscriptionWord{
Word: word.Word,
Start: word.Start,
End: word.End,
})
}
return result
}
func usageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) provider.TranscriptionUsage {
switch usage.Type {
case "tokens":
tokens := usage.AsTokens()
return provider.TranscriptionUsage{
Type: usage.Type,
InputTokens: tokens.InputTokens,
OutputTokens: tokens.OutputTokens,
TotalTokens: tokens.TotalTokens,
AudioTokens: tokens.InputTokenDetails.AudioTokens,
TextTokens: tokens.InputTokenDetails.TextTokens,
}
case "duration":
duration := usage.AsDuration()
return provider.TranscriptionUsage{
Type: usage.Type,
Seconds: duration.Seconds,
}
default:
return provider.TranscriptionUsage{}
}
}

v2/provider/provider.go Normal file

@@ -0,0 +1,100 @@
// Package provider defines the interface that LLM backend implementations must satisfy.
package provider
import "context"
// Message is the provider-level message representation.
type Message struct {
Role string
Content string
Images []Image
Audio []Audio
ToolCalls []ToolCall
ToolCallID string
}
// Image represents an image attachment at the provider level.
type Image struct {
URL string
Base64 string
ContentType string
}
// Audio represents an audio attachment at the provider level.
type Audio struct {
URL string
Base64 string
ContentType string
}
// ToolCall represents a tool invocation requested by the model.
type ToolCall struct {
ID string
Name string
Arguments string // raw JSON
}
// ToolDef defines a tool available to the model.
type ToolDef struct {
Name string
Description string
Schema map[string]any // JSON Schema
}
// Request is a completion request at the provider level.
type Request struct {
Model string
Messages []Message
Tools []ToolDef
Temperature *float64
MaxTokens *int
TopP *float64
Stop []string
}
// Response is a completion response at the provider level.
type Response struct {
Text string
ToolCalls []ToolCall
Usage *Usage
}
// Usage captures token consumption.
type Usage struct {
InputTokens int
OutputTokens int
TotalTokens int
}
// StreamEventType identifies the kind of stream event.
type StreamEventType int
const (
StreamEventText StreamEventType = iota // Text content delta
StreamEventToolStart // Tool call begins
StreamEventToolDelta // Tool call argument delta
StreamEventToolEnd // Tool call complete
StreamEventDone // Stream complete
StreamEventError // Error occurred
)
// StreamEvent represents a single event in a streaming response.
type StreamEvent struct {
Type StreamEventType
Text string
ToolCall *ToolCall
ToolIndex int
Error error
Response *Response
}
// Provider is the interface that LLM backends implement.
type Provider interface {
// Complete performs a non-streaming completion.
Complete(ctx context.Context, req Request) (Response, error)
// Stream performs a streaming completion, sending events to the channel.
// The provider MUST close the channel when done.
// The provider MUST send exactly one StreamEventDone as the last non-error event.
Stream(ctx context.Context, req Request, events chan<- StreamEvent) error
}


@@ -0,0 +1,90 @@
package provider
import "context"
// Transcriber abstracts a speech-to-text model implementation.
type Transcriber interface {
Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error)
}
// TranscriptionResponseFormat controls the output format requested from a transcriber.
type TranscriptionResponseFormat string
const (
TranscriptionResponseFormatJSON TranscriptionResponseFormat = "json"
TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json"
TranscriptionResponseFormatText TranscriptionResponseFormat = "text"
TranscriptionResponseFormatSRT TranscriptionResponseFormat = "srt"
TranscriptionResponseFormatVTT TranscriptionResponseFormat = "vtt"
)
// TranscriptionTimestampGranularity defines the requested timestamp detail.
type TranscriptionTimestampGranularity string
const (
TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word"
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
)
// TranscriptionOptions configures transcription behavior.
type TranscriptionOptions struct {
Language string
Prompt string
Temperature *float64
ResponseFormat TranscriptionResponseFormat
TimestampGranularities []TranscriptionTimestampGranularity
IncludeLogprobs bool
}
// Transcription captures a normalized transcription result.
type Transcription struct {
Provider string
Model string
Text string
Language string
DurationSeconds float64
Segments []TranscriptionSegment
Words []TranscriptionWord
Logprobs []TranscriptionTokenLogprob
Usage TranscriptionUsage
RawJSON string
}
// TranscriptionSegment provides a coarse time-sliced transcription segment.
type TranscriptionSegment struct {
ID int
Start float64
End float64
Text string
Tokens []int
AvgLogprob *float64
CompressionRatio *float64
NoSpeechProb *float64
Words []TranscriptionWord
}
// TranscriptionWord provides a word-level timestamp.
type TranscriptionWord struct {
Word string
Start float64
End float64
Confidence *float64
}
// TranscriptionTokenLogprob captures token-level log probability details.
type TranscriptionTokenLogprob struct {
Token string
Bytes []float64
Logprob float64
}
// TranscriptionUsage captures token or duration usage details.
type TranscriptionUsage struct {
Type string
InputTokens int64
OutputTokens int64
TotalTokens int64
AudioTokens int64
TextTokens int64
Seconds float64
}

v2/request.go Normal file

@@ -0,0 +1,37 @@
package llm
// RequestOption configures a single completion request.
type RequestOption func(*requestConfig)
type requestConfig struct {
tools *ToolBox
temperature *float64
maxTokens *int
topP *float64
stop []string
}
// WithTools attaches a toolbox to the request.
func WithTools(tb *ToolBox) RequestOption {
return func(c *requestConfig) { c.tools = tb }
}
// WithTemperature sets the sampling temperature.
func WithTemperature(t float64) RequestOption {
return func(c *requestConfig) { c.temperature = &t }
}
// WithMaxTokens sets the maximum number of tokens to generate.
func WithMaxTokens(n int) RequestOption {
return func(c *requestConfig) { c.maxTokens = &n }
}
// WithTopP sets the nucleus sampling parameter.
func WithTopP(p float64) RequestOption {
return func(c *requestConfig) { c.topP = &p }
}
// WithStop sets stop sequences.
func WithStop(sequences ...string) RequestOption {
return func(c *requestConfig) { c.stop = sequences }
}

v2/request_test.go Normal file

@@ -0,0 +1,137 @@
package llm
import (
"context"
"testing"
)
func TestWithTemperature(t *testing.T) {
cfg := &requestConfig{}
WithTemperature(0.7)(cfg)
if cfg.temperature == nil || *cfg.temperature != 0.7 {
t.Errorf("expected temperature 0.7, got %v", cfg.temperature)
}
}
func TestWithMaxTokens(t *testing.T) {
cfg := &requestConfig{}
WithMaxTokens(256)(cfg)
if cfg.maxTokens == nil || *cfg.maxTokens != 256 {
t.Errorf("expected maxTokens 256, got %v", cfg.maxTokens)
}
}
func TestWithTopP(t *testing.T) {
cfg := &requestConfig{}
WithTopP(0.95)(cfg)
if cfg.topP == nil || *cfg.topP != 0.95 {
t.Errorf("expected topP 0.95, got %v", cfg.topP)
}
}
func TestWithStop(t *testing.T) {
cfg := &requestConfig{}
WithStop("END", "STOP", "###")(cfg)
if len(cfg.stop) != 3 {
t.Fatalf("expected 3 stop sequences, got %d", len(cfg.stop))
}
if cfg.stop[0] != "END" || cfg.stop[1] != "STOP" || cfg.stop[2] != "###" {
t.Errorf("unexpected stop sequences: %v", cfg.stop)
}
}
func TestWithTools(t *testing.T) {
tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) {
return "ok", nil
})
tb := NewToolBox(tool)
cfg := &requestConfig{}
WithTools(tb)(cfg)
if cfg.tools == nil {
t.Fatal("expected tools to be set")
}
if len(cfg.tools.AllTools()) != 1 {
t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools()))
}
}
func TestBuildProviderRequest(t *testing.T) {
tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) {
return "hi", nil
})
tb := NewToolBox(tool)
temp := 0.8
maxTok := 512
topP := 0.9
cfg := &requestConfig{
tools: tb,
temperature: &temp,
maxTokens: &maxTok,
topP: &topP,
stop: []string{"END"},
}
msgs := []Message{
SystemMessage("be nice"),
UserMessage("hello"),
}
req := buildProviderRequest("test-model", msgs, cfg)
if req.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", req.Model)
}
if len(req.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
}
if req.Messages[0].Role != "system" {
t.Errorf("expected first message role='system', got %q", req.Messages[0].Role)
}
if req.Messages[1].Role != "user" {
t.Errorf("expected second message role='user', got %q", req.Messages[1].Role)
}
if req.Temperature == nil || *req.Temperature != 0.8 {
t.Errorf("expected temperature 0.8, got %v", req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != 512 {
t.Errorf("expected maxTokens 512, got %v", req.MaxTokens)
}
if req.TopP == nil || *req.TopP != 0.9 {
t.Errorf("expected topP 0.9, got %v", req.TopP)
}
if len(req.Stop) != 1 || req.Stop[0] != "END" {
t.Errorf("expected stop=[END], got %v", req.Stop)
}
if len(req.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
}
if req.Tools[0].Name != "greet" {
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
}
}
func TestBuildProviderRequest_EmptyConfig(t *testing.T) {
cfg := &requestConfig{}
msgs := []Message{UserMessage("hi")}
req := buildProviderRequest("model", msgs, cfg)
if req.Temperature != nil {
t.Errorf("expected nil temperature, got %v", req.Temperature)
}
if req.MaxTokens != nil {
t.Errorf("expected nil maxTokens, got %v", req.MaxTokens)
}
if req.TopP != nil {
t.Errorf("expected nil topP, got %v", req.TopP)
}
if len(req.Stop) != 0 {
t.Errorf("expected no stop sequences, got %v", req.Stop)
}
if len(req.Tools) != 0 {
t.Errorf("expected no tools, got %d", len(req.Tools))
}
}

v2/response.go Normal file

@@ -0,0 +1,34 @@
package llm
// Response represents the result of a completion request.
type Response struct {
// Text is the assistant's text content. Empty if only tool calls.
Text string
// ToolCalls contains any tool invocations the assistant requested.
ToolCalls []ToolCall
// Usage contains token usage information (if available from provider).
Usage *Usage
// message is the full assistant message for this response.
message Message
}
// Message returns the full assistant Message for this response,
// suitable for appending to the conversation history.
func (r Response) Message() Message {
return r.message
}
// HasToolCalls returns true if the response contains tool call requests.
func (r Response) HasToolCalls() bool {
return len(r.ToolCalls) > 0
}
// Usage captures token consumption.
type Usage struct {
InputTokens int
OutputTokens int
TotalTokens int
}

v2/sandbox/doc.go Normal file

@@ -0,0 +1,78 @@
// Package sandbox provides isolated Linux container environments for LLM agents.
//
// It manages the full lifecycle of Proxmox LXC containers — cloning from a template,
// starting, connecting via SSH, executing commands, transferring files, and destroying
// the container when done. Each sandbox is an ephemeral, unprivileged container on an
// isolated network bridge with no LAN access.
//
// # Architecture
//
// The package has three layers:
//
// - ProxmoxClient: thin REST client for the Proxmox VE API (container CRUD, IP discovery)
// - SSHExecutor: persistent SSH/SFTP connection for command execution and file transfer
// - Manager/Sandbox: high-level orchestrator that ties Proxmox + SSH together
//
// # Usage
//
// // Load SSH key for container access.
// signer, err := sandbox.LoadSSHKey("/etc/mort/sandbox_key")
// if err != nil {
// log.Fatal(err)
// }
//
// // Create a manager.
// mgr, err := sandbox.NewManager(sandbox.Config{
// Proxmox: sandbox.ProxmoxConfig{
// BaseURL: "https://proxmox.local:8006",
// TokenID: "mort-sandbox@pve!sandbox-token",
// Secret: os.Getenv("SANDBOX_PROXMOX_SECRET"),
// Node: "pve",
// TemplateID: 9000,
// Pool: "sandbox-pool",
// Bridge: "vmbr1",
// },
// SSH: sandbox.SSHConfig{
// Signer: signer,
// },
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Create a sandbox.
// ctx := context.Background()
// sb, err := mgr.Create(ctx,
// sandbox.WithHostname("user-abc"),
// sandbox.WithInternet(true),
// )
// if err != nil {
// log.Fatal(err)
// }
// defer sb.Destroy(ctx)
//
// // Execute commands.
// result, err := sb.Exec(ctx, "apt-get update && apt-get install -y nginx")
// if err != nil {
// log.Fatal(err)
// }
// fmt.Printf("exit %d: %s\n", result.ExitCode, result.Output)
//
// // Write files.
// err = sb.WriteFile(ctx, "/var/www/html/index.html", "<h1>Hello</h1>")
//
// // Read files.
// content, err := sb.ReadFile(ctx, "/etc/nginx/nginx.conf")
//
// # Security
//
// Sandboxes are secured through defense in depth:
// - Unprivileged LXC containers (UID mapping to high host UIDs)
// - Isolated network bridge with nftables default-deny outbound
// - Per-container opt-in internet access (HTTP/HTTPS only)
// - Resource limits: CPU, memory, disk, PID count
// - AppArmor confinement (lxc-container-default-cgns)
// - Capability dropping (sys_admin, sys_rawio, sys_ptrace, etc.)
//
// See docs/sandbox-setup.md for the complete Proxmox setup and hardening guide.
package sandbox

v2/sandbox/proxmox.go Normal file

@@ -0,0 +1,410 @@
package sandbox
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// ProxmoxConfig holds configuration for connecting to a Proxmox VE host.
type ProxmoxConfig struct {
// BaseURL is the Proxmox API base URL (e.g., "https://proxmox.local:8006").
BaseURL string
// TokenID is the API token identifier (e.g., "mort-sandbox@pve!sandbox-token").
TokenID string
// Secret is the API token secret.
Secret string
// Node is the Proxmox node name (e.g., "pve").
Node string
// TemplateID is the LXC template container ID to clone from (e.g., 9000).
TemplateID int
// Pool is the Proxmox resource pool for sandbox containers (e.g., "sandbox-pool").
Pool string
// Bridge is the network bridge for containers (e.g., "vmbr1").
Bridge string
// InsecureSkipVerify disables TLS certificate verification.
// Use only for self-signed Proxmox certificates.
InsecureSkipVerify bool
}
// ContainerStatus represents the current state of a Proxmox LXC container.
type ContainerStatus struct {
Status string `json:"status"` // "running", "stopped", etc.
CPU float64 `json:"cpu"` // CPU usage fraction (0.0 to 1.0)
Mem int64 `json:"mem"` // Current memory usage in bytes
MaxMem int64 `json:"maxmem"` // Maximum memory in bytes
Disk int64 `json:"disk"` // Current disk usage in bytes
MaxDisk int64 `json:"maxdisk"` // Maximum disk in bytes
NetIn int64 `json:"netin"` // Network bytes received
NetOut int64 `json:"netout"` // Network bytes sent
Uptime int64 `json:"uptime"` // Uptime in seconds
}
// ContainerConfig holds settings for creating a new container.
type ContainerConfig struct {
// Hostname for the container.
Hostname string
// CPUs is the number of CPU cores (default 1).
CPUs int
// MemoryMB is the memory limit in megabytes (default 1024).
MemoryMB int
// DiskGB is the root filesystem size in gigabytes (default 8).
DiskGB int
// SSHPublicKey is an optional SSH public key to inject.
SSHPublicKey string
}
// ProxmoxClient is a thin REST API client for Proxmox VE container lifecycle management.
type ProxmoxClient struct {
config ProxmoxConfig
http *http.Client
}
// NewProxmoxClient creates a new Proxmox API client.
func NewProxmoxClient(config ProxmoxConfig) *ProxmoxClient {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.InsecureSkipVerify,
},
}
return &ProxmoxClient{
config: config,
http: &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
},
}
}
// NextAvailableID queries Proxmox for the next free VMID.
func (p *ProxmoxClient) NextAvailableID(ctx context.Context) (int, error) {
var result int
err := p.get(ctx, "/api2/json/cluster/nextid", &result)
if err != nil {
return 0, fmt.Errorf("get next VMID: %w", err)
}
return result, nil
}
// CloneTemplate clones the configured template into a new container with the given VMID.
func (p *ProxmoxClient) CloneTemplate(ctx context.Context, newID int, cfg ContainerConfig) error {
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/clone", p.config.Node, p.config.TemplateID)
hostname := cfg.Hostname
if hostname == "" {
hostname = fmt.Sprintf("sandbox-%d", newID)
}
params := url.Values{
"newid": {fmt.Sprintf("%d", newID)},
"hostname": {hostname},
"full": {"1"},
}
if p.config.Pool != "" {
params.Set("pool", p.config.Pool)
}
taskID, err := p.post(ctx, path, params)
if err != nil {
return fmt.Errorf("clone template %d → %d: %w", p.config.TemplateID, newID, err)
}
return p.waitForTask(ctx, taskID)
}
// ConfigureContainer sets CPU, memory, and network on an existing container.
func (p *ProxmoxClient) ConfigureContainer(ctx context.Context, id int, cfg ContainerConfig) error {
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/config", p.config.Node, id)
cpus := cfg.CPUs
if cpus <= 0 {
cpus = 1
}
mem := cfg.MemoryMB
if mem <= 0 {
mem = 1024
}
params := url.Values{
"cores": {fmt.Sprintf("%d", cpus)},
"memory": {fmt.Sprintf("%d", mem)},
"swap": {"0"},
"net0": {fmt.Sprintf("name=eth0,bridge=%s,ip=dhcp", p.config.Bridge)},
}
_, err := p.put(ctx, path, params)
if err != nil {
return fmt.Errorf("configure container %d: %w", id, err)
}
return nil
}
// StartContainer starts a stopped container.
func (p *ProxmoxClient) StartContainer(ctx context.Context, id int) error {
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/status/start", p.config.Node, id)
taskID, err := p.post(ctx, path, nil)
if err != nil {
return fmt.Errorf("start container %d: %w", id, err)
}
return p.waitForTask(ctx, taskID)
}
// StopContainer stops a running container.
func (p *ProxmoxClient) StopContainer(ctx context.Context, id int) error {
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/status/stop", p.config.Node, id)
taskID, err := p.post(ctx, path, nil)
if err != nil {
return fmt.Errorf("stop container %d: %w", id, err)
}
return p.waitForTask(ctx, taskID)
}
// DestroyContainer stops (if running) and permanently deletes a container.
func (p *ProxmoxClient) DestroyContainer(ctx context.Context, id int) error {
// Stop the container first if it is running; stop errors are ignored,
// since the forced delete below handles a stuck container.
status, err := p.GetContainerStatus(ctx, id)
if err != nil {
return fmt.Errorf("get status before destroy: %w", err)
}
if status.Status == "running" {
_ = p.StopContainer(ctx, id)
}
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d", p.config.Node, id)
params := url.Values{"force": {"1"}, "purge": {"1"}}
taskID, err := p.delete(ctx, path, params)
if err != nil {
return fmt.Errorf("destroy container %d: %w", id, err)
}
return p.waitForTask(ctx, taskID)
}
// GetContainerStatus returns the current status and resource usage of a container.
func (p *ProxmoxClient) GetContainerStatus(ctx context.Context, id int) (ContainerStatus, error) {
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/status/current", p.config.Node, id)
var status ContainerStatus
if err := p.get(ctx, path, &status); err != nil {
return ContainerStatus{}, fmt.Errorf("get container %d status: %w", id, err)
}
return status, nil
}
// GetContainerIP discovers the container's IP address by querying its network interfaces.
// It polls until an IP is found or the context is cancelled.
func (p *ProxmoxClient) GetContainerIP(ctx context.Context, id int) (string, error) {
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/interfaces", p.config.Node, id)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
var ifaces []struct {
Name string `json:"name"`
HWAddr string `json:"hwaddr"`
Inet string `json:"inet"`
Inet6 string `json:"inet6"`
}
if err := p.get(ctx, path, &ifaces); err == nil {
for _, iface := range ifaces {
if iface.Name == "lo" || iface.Inet == "" {
continue
}
// Inet is in CIDR format (e.g., "10.99.1.5/16")
ip := iface.Inet
if idx := strings.IndexByte(ip, '/'); idx > 0 {
ip = ip[:idx]
}
return ip, nil
}
}
select {
case <-ctx.Done():
return "", fmt.Errorf("get container %d IP: %w", id, ctx.Err())
case <-ticker.C:
}
}
}
// EnableInternet adds a container IP to the nftables internet_allowed set,
// granting outbound HTTP/HTTPS access.
func (p *ProxmoxClient) EnableInternet(ctx context.Context, containerIP string) error {
return p.execOnHost(ctx, fmt.Sprintf("nft add element inet sandbox internet_allowed { %s }", containerIP))
}
// DisableInternet removes a container IP from the nftables internet_allowed set,
// revoking outbound HTTP/HTTPS access.
func (p *ProxmoxClient) DisableInternet(ctx context.Context, containerIP string) error {
return p.execOnHost(ctx, fmt.Sprintf("nft delete element inet sandbox internet_allowed { %s }", containerIP))
}
// execOnHost runs a command on the Proxmox host via the API's node exec endpoint.
func (p *ProxmoxClient) execOnHost(ctx context.Context, command string) error {
path := fmt.Sprintf("/api2/json/nodes/%s/execute", p.config.Node)
params := url.Values{"commands": {command}}
_, err := p.post(ctx, path, params)
if err != nil {
return fmt.Errorf("exec on host: %w", err)
}
return nil
}
// --- HTTP helpers ---
// proxmoxResponse is the standard envelope for all Proxmox API responses.
type proxmoxResponse struct {
Data json.RawMessage `json:"data"`
}
func (p *ProxmoxClient) doRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
u := strings.TrimRight(p.config.BaseURL, "/") + path
req, err := http.NewRequestWithContext(ctx, method, u, body)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("PVEAPIToken=%s=%s", p.config.TokenID, p.config.Secret))
if body != nil {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
resp, err := p.http.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func (p *ProxmoxClient) get(ctx context.Context, path string, result any) error {
resp, err := p.doRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return err
}
defer resp.Body.Close()
return p.parseResponse(resp, result)
}
func (p *ProxmoxClient) post(ctx context.Context, path string, params url.Values) (string, error) {
var body io.Reader
if params != nil {
body = strings.NewReader(params.Encode())
}
resp, err := p.doRequest(ctx, http.MethodPost, path, body)
if err != nil {
return "", err
}
defer resp.Body.Close()
var taskID string
if err := p.parseResponse(resp, &taskID); err != nil {
return "", err
}
return taskID, nil
}
func (p *ProxmoxClient) put(ctx context.Context, path string, params url.Values) (string, error) {
var body io.Reader
if params != nil {
body = strings.NewReader(params.Encode())
}
resp, err := p.doRequest(ctx, http.MethodPut, path, body)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result string
if err := p.parseResponse(resp, &result); err != nil {
return "", err
}
return result, nil
}
func (p *ProxmoxClient) delete(ctx context.Context, path string, params url.Values) (string, error) {
if params != nil {
path = path + "?" + params.Encode()
}
resp, err := p.doRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return "", err
}
defer resp.Body.Close()
var taskID string
if err := p.parseResponse(resp, &taskID); err != nil {
return "", err
}
return taskID, nil
}
func (p *ProxmoxClient) parseResponse(resp *http.Response, result any) error {
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
bodyBytes, _ := io.ReadAll(resp.Body)
return fmt.Errorf("proxmox API error (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var envelope proxmoxResponse
if err := json.NewDecoder(resp.Body).Decode(&envelope); err != nil {
return fmt.Errorf("decode response: %w", err)
}
if result == nil {
return nil
}
if err := json.Unmarshal(envelope.Data, result); err != nil {
return fmt.Errorf("unmarshal data: %w", err)
}
return nil
}
// waitForTask polls a Proxmox task until it completes or the context is cancelled.
func (p *ProxmoxClient) waitForTask(ctx context.Context, taskID string) error {
if taskID == "" {
return nil
}
path := fmt.Sprintf("/api2/json/nodes/%s/tasks/%s/status", p.config.Node, url.PathEscape(taskID))
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
var status struct {
Status string `json:"status"` // "running", "stopped", etc.
ExitCode string `json:"exitstatus"`
}
if err := p.get(ctx, path, &status); err != nil {
return fmt.Errorf("poll task %s: %w", taskID, err)
}
if status.Status != "running" {
if status.ExitCode != "OK" && status.ExitCode != "" {
return fmt.Errorf("task %s failed: %s", taskID, status.ExitCode)
}
return nil
}
select {
case <-ctx.Done():
return fmt.Errorf("wait for task %s: %w", taskID, ctx.Err())
case <-ticker.C:
}
}
}
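The HTTP helpers above all funnel through `parseResponse`, which unwraps Proxmox's standard `{"data": …}` envelope before decoding the typed payload — the same two-step decode `waitForTask` relies on for task status. A minimal, self-contained sketch of that pattern (the struct names here are illustrative, not the package's exports):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// envelope mirrors the proxmoxResponse wrapper: every Proxmox API
// reply nests its payload under a top-level "data" key.
type envelope struct {
	Data json.RawMessage `json:"data"`
}

// taskStatus mirrors the struct waitForTask decodes into.
type taskStatus struct {
	Status   string `json:"status"`
	ExitCode string `json:"exitstatus"`
}

// decodeTask unwraps the envelope, then unmarshals the inner payload.
func decodeTask(raw []byte) (taskStatus, error) {
	var env envelope
	if err := json.Unmarshal(raw, &env); err != nil {
		return taskStatus{}, fmt.Errorf("decode envelope: %w", err)
	}
	var st taskStatus
	if err := json.Unmarshal(env.Data, &st); err != nil {
		return taskStatus{}, fmt.Errorf("unmarshal data: %w", err)
	}
	return st, nil
}

func main() {
	body := []byte(`{"data":{"status":"stopped","exitstatus":"OK"}}`)
	st, err := decodeTask(body)
	if err != nil {
		panic(err)
	}
	fmt.Println(st.Status, st.ExitCode)
}
```

Keeping `Data` as `json.RawMessage` is what lets one envelope type serve every endpoint: the outer decode is uniform, and each caller supplies its own target type for the inner unmarshal.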

v2/sandbox/sandbox.go Normal file

@@ -0,0 +1,310 @@
package sandbox
import (
"context"
"fmt"
"io"
"os"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
// Config holds all configuration for creating sandboxes.
type Config struct {
Proxmox ProxmoxConfig
SSH SSHConfig
Defaults ContainerConfig
}
// Option configures a Sandbox before creation.
type Option func(*createOpts)
type createOpts struct {
hostname string
cpus int
memoryMB int
diskGB int
internet bool
}
// WithHostname sets the container hostname.
func WithHostname(name string) Option {
return func(o *createOpts) { o.hostname = name }
}
// WithCPUs sets the number of CPU cores for the container.
func WithCPUs(n int) Option {
return func(o *createOpts) { o.cpus = n }
}
// WithMemoryMB sets the memory limit in megabytes.
func WithMemoryMB(mb int) Option {
return func(o *createOpts) { o.memoryMB = mb }
}
// WithDiskGB sets the root filesystem size in gigabytes.
func WithDiskGB(gb int) Option {
return func(o *createOpts) { o.diskGB = gb }
}
// WithInternet enables outbound HTTP/HTTPS access on creation.
func WithInternet(enabled bool) Option {
return func(o *createOpts) { o.internet = enabled }
}
// Sandbox represents an isolated Linux container environment with SSH access.
// It wraps a Proxmox LXC container and provides command execution and file operations.
type Sandbox struct {
// ID is the Proxmox VMID of this container.
ID int
// IP is the container's IP address on the isolated bridge.
IP string
// Internet indicates whether outbound HTTP/HTTPS is enabled.
Internet bool
proxmox *ProxmoxClient
ssh *SSHExecutor
}
// Manager creates and manages sandbox instances.
type Manager struct {
proxmox *ProxmoxClient
sshKey ssh.Signer
defaults ContainerConfig
sshCfg SSHConfig
}
// NewManager creates a new sandbox manager from the given configuration.
func NewManager(cfg Config) (*Manager, error) {
if cfg.SSH.Signer == nil {
return nil, fmt.Errorf("SSH signer is required")
}
return &Manager{
proxmox: NewProxmoxClient(cfg.Proxmox),
sshKey: cfg.SSH.Signer,
defaults: cfg.Defaults,
sshCfg: cfg.SSH,
}, nil
}
// Create provisions a new sandbox container: clones the template, starts it,
// waits for SSH, and optionally enables internet access.
// The returned Sandbox must be destroyed with Destroy when no longer needed.
func (m *Manager) Create(ctx context.Context, opts ...Option) (*Sandbox, error) {
o := &createOpts{
hostname: m.defaults.Hostname,
cpus: m.defaults.CPUs,
memoryMB: m.defaults.MemoryMB,
diskGB: m.defaults.DiskGB,
}
for _, opt := range opts {
opt(o)
}
// Apply defaults for zero values.
if o.cpus <= 0 {
o.cpus = 1
}
if o.memoryMB <= 0 {
o.memoryMB = 1024
}
if o.diskGB <= 0 {
o.diskGB = 8
}
// Get next VMID.
vmid, err := m.proxmox.NextAvailableID(ctx)
if err != nil {
return nil, fmt.Errorf("get next VMID: %w", err)
}
containerCfg := ContainerConfig{
Hostname: o.hostname,
CPUs: o.cpus,
MemoryMB: o.memoryMB,
DiskGB: o.diskGB,
}
// Clone template.
if err := m.proxmox.CloneTemplate(ctx, vmid, containerCfg); err != nil {
return nil, fmt.Errorf("clone template: %w", err)
}
// Configure container resources.
if err := m.proxmox.ConfigureContainer(ctx, vmid, containerCfg); err != nil {
// Clean up the cloned container on failure.
_ = m.proxmox.DestroyContainer(ctx, vmid)
return nil, fmt.Errorf("configure container: %w", err)
}
// Start container.
if err := m.proxmox.StartContainer(ctx, vmid); err != nil {
_ = m.proxmox.DestroyContainer(ctx, vmid)
return nil, fmt.Errorf("start container: %w", err)
}
// Discover IP address (with timeout).
ipCtx, ipCancel := context.WithTimeout(ctx, 30*time.Second)
defer ipCancel()
ip, err := m.proxmox.GetContainerIP(ipCtx, vmid)
if err != nil {
_ = m.proxmox.DestroyContainer(ctx, vmid)
return nil, fmt.Errorf("discover IP: %w", err)
}
// Connect SSH (with timeout).
sshExec := NewSSHExecutor(ip, m.sshCfg)
sshCtx, sshCancel := context.WithTimeout(ctx, 30*time.Second)
defer sshCancel()
if err := sshExec.Connect(sshCtx); err != nil {
_ = m.proxmox.DestroyContainer(ctx, vmid)
return nil, fmt.Errorf("ssh connect: %w", err)
}
sb := &Sandbox{
ID: vmid,
IP: ip,
proxmox: m.proxmox,
ssh: sshExec,
}
// Enable internet if requested.
if o.internet {
if err := sb.SetInternet(ctx, true); err != nil {
_ = sb.Destroy(ctx) // best-effort cleanup; the create error is what matters
return nil, fmt.Errorf("enable internet: %w", err)
}
}
return sb, nil
}
// Attach reconnects to an existing sandbox container by VMID.
// This is useful for recovering sessions after a restart.
func (m *Manager) Attach(ctx context.Context, vmid int) (*Sandbox, error) {
status, err := m.proxmox.GetContainerStatus(ctx, vmid)
if err != nil {
return nil, fmt.Errorf("get container status: %w", err)
}
if status.Status != "running" {
return nil, fmt.Errorf("container %d is not running (status: %s)", vmid, status.Status)
}
ip, err := m.proxmox.GetContainerIP(ctx, vmid)
if err != nil {
return nil, fmt.Errorf("get container IP: %w", err)
}
sshExec := NewSSHExecutor(ip, m.sshCfg)
if err := sshExec.Connect(ctx); err != nil {
return nil, fmt.Errorf("ssh connect: %w", err)
}
return &Sandbox{
ID: vmid,
IP: ip,
proxmox: m.proxmox,
ssh: sshExec,
}, nil
}
// Exec runs a shell command in the sandbox and returns the result.
func (s *Sandbox) Exec(ctx context.Context, command string) (ExecResult, error) {
return s.ssh.Exec(ctx, command)
}
// WriteFile creates or overwrites a file in the sandbox.
func (s *Sandbox) WriteFile(ctx context.Context, path, content string) error {
return s.ssh.Upload(ctx, strings.NewReader(content), path, 0644)
}
// ReadFile reads a file from the sandbox and returns its contents.
func (s *Sandbox) ReadFile(ctx context.Context, path string) (string, error) {
rc, err := s.ssh.Download(ctx, path)
if err != nil {
return "", err
}
defer rc.Close()
data, err := io.ReadAll(rc)
if err != nil {
return "", fmt.Errorf("read file %s: %w", path, err)
}
return string(data), nil
}
// Upload copies data from an io.Reader to a file in the sandbox.
func (s *Sandbox) Upload(ctx context.Context, reader io.Reader, remotePath string, mode os.FileMode) error {
return s.ssh.Upload(ctx, reader, remotePath, mode)
}
// Download returns an io.ReadCloser for a file in the sandbox.
// The caller must close the returned reader.
func (s *Sandbox) Download(ctx context.Context, remotePath string) (io.ReadCloser, error) {
return s.ssh.Download(ctx, remotePath)
}
// SetInternet enables or disables outbound HTTP/HTTPS access for the sandbox.
func (s *Sandbox) SetInternet(ctx context.Context, enabled bool) error {
if enabled {
if err := s.proxmox.EnableInternet(ctx, s.IP); err != nil {
return err
}
} else {
if err := s.proxmox.DisableInternet(ctx, s.IP); err != nil {
return err
}
}
s.Internet = enabled
return nil
}
// Status returns the current resource usage of the sandbox container.
func (s *Sandbox) Status(ctx context.Context) (ContainerStatus, error) {
return s.proxmox.GetContainerStatus(ctx, s.ID)
}
// IsConnected returns true if the SSH connection to the sandbox is active.
func (s *Sandbox) IsConnected() bool {
return s.ssh.IsConnected()
}
// Destroy stops the container, removes internet access, closes SSH connections,
// and permanently deletes the container from Proxmox.
func (s *Sandbox) Destroy(ctx context.Context) error {
var errs []error
// Remove internet access first (ignore errors — container is being destroyed).
if s.Internet {
_ = s.proxmox.DisableInternet(ctx, s.IP)
}
// Close SSH connections.
if err := s.ssh.Close(); err != nil {
errs = append(errs, fmt.Errorf("close ssh: %w", err))
}
// Destroy the container.
if err := s.proxmox.DestroyContainer(ctx, s.ID); err != nil {
errs = append(errs, fmt.Errorf("destroy container: %w", err))
}
if len(errs) > 0 {
return fmt.Errorf("destroy sandbox %d: %v", s.ID, errs)
}
return nil
}
// DestroyByID destroys a container by VMID without requiring an active SSH connection.
// This is useful for cleaning up orphaned containers after a restart.
func (m *Manager) DestroyByID(ctx context.Context, vmid int) error {
return m.proxmox.DestroyContainer(ctx, vmid)
}

v2/sandbox/sandbox_test.go Normal file

File diff suppressed because it is too large

v2/sandbox/ssh.go Normal file

@@ -0,0 +1,253 @@
package sandbox
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// SSHConfig holds configuration for SSH connections to sandbox containers.
type SSHConfig struct {
// User is the SSH username (default "sandbox").
User string
// Signer is the SSH private key signer for authentication.
Signer ssh.Signer
// ConnectTimeout is the maximum time to wait for an SSH connection (default 10s).
ConnectTimeout time.Duration
// CommandTimeout is the default maximum time for a single command execution (default 60s).
CommandTimeout time.Duration
}
// SSHExecutor manages SSH and SFTP connections to a sandbox container.
type SSHExecutor struct {
host string
config SSHConfig
mu sync.Mutex
sshClient *ssh.Client
sftpClient *sftp.Client
}
// NewSSHExecutor creates a new SSH executor for the given host.
func NewSSHExecutor(host string, config SSHConfig) *SSHExecutor {
if config.User == "" {
config.User = "sandbox"
}
if config.ConnectTimeout <= 0 {
config.ConnectTimeout = 10 * time.Second
}
if config.CommandTimeout <= 0 {
config.CommandTimeout = 60 * time.Second
}
return &SSHExecutor{
host: host,
config: config,
}
}
// Connect establishes SSH and SFTP connections to the container.
// It polls until the connection succeeds or the context is cancelled,
// which is useful when waiting for a freshly started container to boot.
func (s *SSHExecutor) Connect(ctx context.Context) error {
sshConfig := &ssh.ClientConfig{
User: s.config.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(s.config.Signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: s.config.ConnectTimeout,
}
addr := net.JoinHostPort(s.host, "22")
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
var lastErr error
for {
client, err := ssh.Dial("tcp", addr, sshConfig)
if err == nil {
sftpClient, err := sftp.NewClient(client)
if err != nil {
client.Close()
return fmt.Errorf("create SFTP client: %w", err)
}
s.mu.Lock()
s.sshClient = client
s.sftpClient = sftpClient
s.mu.Unlock()
return nil
}
lastErr = err
select {
case <-ctx.Done():
return fmt.Errorf("ssh connect to %s: %w (last error: %v)", addr, ctx.Err(), lastErr)
case <-ticker.C:
}
}
}
// ExecResult contains the output and exit status of a command execution.
type ExecResult struct {
Output string
ExitCode int
}
// Exec runs a shell command on the container and returns the combined stdout/stderr
// output and exit code.
func (s *SSHExecutor) Exec(ctx context.Context, command string) (ExecResult, error) {
s.mu.Lock()
client := s.sshClient
s.mu.Unlock()
if client == nil {
return ExecResult{}, fmt.Errorf("ssh not connected")
}
session, err := client.NewSession()
if err != nil {
return ExecResult{}, fmt.Errorf("create session: %w", err)
}
defer session.Close()
var buf bytes.Buffer
session.Stdout = &buf
session.Stderr = &buf
// Apply context timeout.
done := make(chan error, 1)
go func() {
done <- session.Run(command)
}()
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGKILL) // best-effort; not all SSH servers honor signals
return ExecResult{}, fmt.Errorf("exec cancelled: %w", ctx.Err())
case err := <-done:
output := buf.String()
if err != nil {
if exitErr, ok := err.(*ssh.ExitError); ok {
return ExecResult{
Output: output,
ExitCode: exitErr.ExitStatus(),
}, nil
}
return ExecResult{Output: output}, fmt.Errorf("exec: %w", err)
}
return ExecResult{Output: output, ExitCode: 0}, nil
}
}
// Upload writes data from an io.Reader to a file on the container.
func (s *SSHExecutor) Upload(ctx context.Context, reader io.Reader, remotePath string, mode os.FileMode) error {
s.mu.Lock()
client := s.sftpClient
s.mu.Unlock()
if client == nil {
return fmt.Errorf("sftp not connected")
}
f, err := client.OpenFile(remotePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
if err != nil {
return fmt.Errorf("open remote file %s: %w", remotePath, err)
}
defer f.Close()
if _, err := io.Copy(f, reader); err != nil {
return fmt.Errorf("write to %s: %w", remotePath, err)
}
if err := client.Chmod(remotePath, mode); err != nil {
return fmt.Errorf("chmod %s: %w", remotePath, err)
}
return nil
}
// Download reads a file from the container and returns its contents as an io.ReadCloser.
// The caller must close the returned reader.
func (s *SSHExecutor) Download(ctx context.Context, remotePath string) (io.ReadCloser, error) {
s.mu.Lock()
client := s.sftpClient
s.mu.Unlock()
if client == nil {
return nil, fmt.Errorf("sftp not connected")
}
f, err := client.Open(remotePath)
if err != nil {
return nil, fmt.Errorf("open remote file %s: %w", remotePath, err)
}
return f, nil
}
// Close tears down both SFTP and SSH connections.
func (s *SSHExecutor) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var errs []error
if s.sftpClient != nil {
if err := s.sftpClient.Close(); err != nil {
errs = append(errs, fmt.Errorf("close SFTP: %w", err))
}
s.sftpClient = nil
}
if s.sshClient != nil {
if err := s.sshClient.Close(); err != nil {
errs = append(errs, fmt.Errorf("close SSH: %w", err))
}
s.sshClient = nil
}
if len(errs) > 0 {
return fmt.Errorf("close ssh executor: %v", errs)
}
return nil
}
// IsConnected returns true if the SSH connection is established.
func (s *SSHExecutor) IsConnected() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.sshClient != nil
}
// LoadSSHKey reads a PEM-encoded private key file and returns an ssh.Signer.
func LoadSSHKey(path string) (ssh.Signer, error) {
keyData, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read SSH key %s: %w", path, err)
}
signer, err := ssh.ParsePrivateKey(keyData)
if err != nil {
return nil, fmt.Errorf("parse SSH key: %w", err)
}
return signer, nil
}
// ParseSSHKey parses a PEM-encoded private key from bytes and returns an ssh.Signer.
func ParseSSHKey(pemBytes []byte) (ssh.Signer, error) {
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return nil, fmt.Errorf("parse SSH key: %w", err)
}
return signer, nil
}
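`Exec` above treats a nonzero exit status as data rather than failure: an `*ssh.ExitError` is folded into `ExecResult.ExitCode`, while genuine transport errors still propagate. The same pattern can be illustrated locally with the standard library's `exec.ExitError` (this is a sketch of the error-handling shape, not the package's SSH path; it assumes a POSIX `sh` is available):

```go
package main

import (
	"errors"
	"fmt"
	"os/exec"
)

type execResult struct {
	Output   string
	ExitCode int
}

// run mirrors SSHExecutor.Exec's error handling: combined output is
// always captured, and a nonzero exit status becomes part of the
// result rather than an error.
func run(command string) (execResult, error) {
	out, err := exec.Command("sh", "-c", command).CombinedOutput()
	if err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			return execResult{Output: string(out), ExitCode: exitErr.ExitCode()}, nil
		}
		return execResult{Output: string(out)}, fmt.Errorf("exec: %w", err)
	}
	return execResult{Output: string(out), ExitCode: 0}, nil
}

func main() {
	res, err := run("echo hello; exit 3")
	if err != nil {
		panic(err)
	}
	fmt.Printf("%q exit=%d\n", res.Output, res.ExitCode)
}
```

The payoff for callers is that `if err != nil` means "the command could not be run", while a failing command is inspected through `ExitCode` — the distinction an LLM agent needs when deciding whether to retry a command or report its output.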

v2/stream.go Normal file

@@ -0,0 +1,163 @@
package llm
import (
"context"
"fmt"
"io"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// StreamEventType identifies the kind of stream event.
type StreamEventType = provider.StreamEventType
const (
StreamEventText = provider.StreamEventText
StreamEventToolStart = provider.StreamEventToolStart
StreamEventToolDelta = provider.StreamEventToolDelta
StreamEventToolEnd = provider.StreamEventToolEnd
StreamEventDone = provider.StreamEventDone
StreamEventError = provider.StreamEventError
)
// StreamEvent represents a single event in a streaming response.
type StreamEvent struct {
Type StreamEventType
// Text is set for StreamEventText — the text delta.
Text string
// ToolCall is set for StreamEventToolStart/ToolDelta/ToolEnd.
ToolCall *ToolCall
// ToolIndex identifies which tool call is being updated.
ToolIndex int
// Error is set for StreamEventError.
Error error
// Response is set for StreamEventDone — the complete, aggregated response.
Response *Response
}
// StreamReader reads streaming events from an LLM response.
// It is not safe for concurrent use and must be closed when done.
type StreamReader struct {
events <-chan StreamEvent
cancel context.CancelFunc
done bool
}
func newStreamReader(ctx context.Context, p provider.Provider, req provider.Request) (*StreamReader, error) {
ctx, cancel := context.WithCancel(ctx)
providerEvents := make(chan provider.StreamEvent, 32)
publicEvents := make(chan StreamEvent, 32)
go func() {
defer close(publicEvents)
for pev := range providerEvents {
ev := convertStreamEvent(pev)
select {
case publicEvents <- ev:
case <-ctx.Done():
return
}
}
}()
go func() {
defer close(providerEvents)
if err := p.Stream(ctx, req, providerEvents); err != nil {
select {
case providerEvents <- provider.StreamEvent{Type: provider.StreamEventError, Error: err}:
case <-ctx.Done():
}
}
}()
return &StreamReader{
events: publicEvents,
cancel: cancel,
}, nil
}
func convertStreamEvent(pev provider.StreamEvent) StreamEvent {
ev := StreamEvent{
Type: pev.Type,
Text: pev.Text,
ToolIndex: pev.ToolIndex,
}
if pev.Error != nil {
ev.Error = pev.Error
}
if pev.ToolCall != nil {
tc := ToolCall{
ID: pev.ToolCall.ID,
Name: pev.ToolCall.Name,
Arguments: pev.ToolCall.Arguments,
}
ev.ToolCall = &tc
}
if pev.Response != nil {
resp := convertProviderResponse(*pev.Response)
ev.Response = &resp
}
return ev
}
// Next returns the next event from the stream.
// Returns io.EOF when the stream is complete.
func (sr *StreamReader) Next() (StreamEvent, error) {
if sr.done {
return StreamEvent{}, io.EOF
}
ev, ok := <-sr.events
if !ok {
sr.done = true
return StreamEvent{}, io.EOF
}
if ev.Type == StreamEventError {
sr.done = true
return ev, ev.Error
}
if ev.Type == StreamEventDone {
sr.done = true
}
return ev, nil
}
// Close closes the stream reader and releases resources.
func (sr *StreamReader) Close() error {
sr.cancel()
return nil
}
// Collect reads all events and returns the final aggregated Response.
func (sr *StreamReader) Collect() (Response, error) {
var lastResp *Response
for {
ev, err := sr.Next()
if err == io.EOF {
break
}
if err != nil {
return Response{}, err
}
if ev.Type == StreamEventDone && ev.Response != nil {
lastResp = ev.Response
}
}
if lastResp == nil {
return Response{}, fmt.Errorf("stream completed without final response")
}
return *lastResp, nil
}
// Text is a convenience that collects the stream and returns just the text.
func (sr *StreamReader) Text() (string, error) {
resp, err := sr.Collect()
if err != nil {
return "", err
}
return resp.Text, nil
}

v2/stream_test.go Normal file

@@ -0,0 +1,338 @@
package llm
import (
"context"
"errors"
"io"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestStreamReader_TextEvents(t *testing.T) {
events := []provider.StreamEvent{
{Type: provider.StreamEventText, Text: "Hello"},
{Type: provider.StreamEventText, Text: " world"},
{Type: provider.StreamEventDone, Response: &provider.Response{
Text: "Hello world",
Usage: &provider.Usage{
InputTokens: 5,
OutputTokens: 2,
TotalTokens: 7,
},
}},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
// Read text events
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error on first event: %v", err)
}
if ev.Type != StreamEventText || ev.Text != "Hello" {
t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text)
}
ev, err = reader.Next()
if err != nil {
t.Fatalf("unexpected error on second event: %v", err)
}
if ev.Type != StreamEventText || ev.Text != " world" {
t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text)
}
// Read done event
ev, err = reader.Next()
if err != nil {
t.Fatalf("unexpected error on done event: %v", err)
}
if ev.Type != StreamEventDone {
t.Errorf("expected done event, got type=%d", ev.Type)
}
if ev.Response == nil {
t.Fatal("expected response in done event")
}
if ev.Response.Text != "Hello world" {
t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text)
}
// Subsequent reads should return EOF
_, err = reader.Next()
if !errors.Is(err, io.EOF) {
t.Errorf("expected io.EOF after done, got %v", err)
}
}
func TestStreamReader_ToolCallEvents(t *testing.T) {
events := []provider.StreamEvent{
{
Type: provider.StreamEventToolStart,
ToolIndex: 0,
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search"},
},
{
Type: provider.StreamEventToolDelta,
ToolIndex: 0,
ToolCall: &provider.ToolCall{Arguments: `{"query":`},
},
{
Type: provider.StreamEventToolDelta,
ToolIndex: 0,
ToolCall: &provider.ToolCall{Arguments: `"test"}`},
},
{
Type: provider.StreamEventToolEnd,
ToolIndex: 0,
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
},
{
Type: provider.StreamEventDone,
Response: &provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
},
},
},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
// Read tool start
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ev.Type != StreamEventToolStart {
t.Errorf("expected tool start, got type=%d", ev.Type)
}
if ev.ToolCall == nil || ev.ToolCall.Name != "search" {
t.Errorf("expected tool call 'search', got %+v", ev.ToolCall)
}
// Read tool deltas
ev, _ = reader.Next()
if ev.Type != StreamEventToolDelta {
t.Errorf("expected tool delta, got type=%d", ev.Type)
}
ev, _ = reader.Next()
if ev.Type != StreamEventToolDelta {
t.Errorf("expected tool delta, got type=%d", ev.Type)
}
// Read tool end
ev, _ = reader.Next()
if ev.Type != StreamEventToolEnd {
t.Errorf("expected tool end, got type=%d", ev.Type)
}
if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` {
t.Errorf("expected complete arguments, got %+v", ev.ToolCall)
}
// Read done
ev, _ = reader.Next()
if ev.Type != StreamEventDone {
t.Errorf("expected done, got type=%d", ev.Type)
}
if ev.Response == nil || len(ev.Response.ToolCalls) != 1 {
t.Error("expected response with 1 tool call")
}
}
func TestStreamReader_Error(t *testing.T) {
streamErr := errors.New("stream failed")
mp := &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"}
ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr}
return nil
},
}
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
// Read partial text
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ev.Text != "partial" {
t.Errorf("expected 'partial', got %q", ev.Text)
}
// Read error
_, err = reader.Next()
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, streamErr) {
t.Errorf("expected stream error, got %v", err)
}
}
func TestStreamReader_Close(t *testing.T) {
// Create a stream that sends one event then blocks until context is cancelled
mp := &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"}
<-ctx.Done()
return ctx.Err()
},
}
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Read the first event
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error on first event: %v", err)
}
if ev.Text != "start" {
t.Errorf("expected 'start', got %q", ev.Text)
}
// Close should cancel context
if err := reader.Close(); err != nil {
t.Fatalf("close error: %v", err)
}
// After close, Next should eventually terminate with either EOF or context error.
// The exact behavior depends on goroutine scheduling: the channel may close (EOF)
// or the error event from the cancelled context may arrive first.
_, err = reader.Next()
if err == nil {
t.Error("expected error after close, got nil")
}
}
func TestStreamReader_Collect(t *testing.T) {
events := []provider.StreamEvent{
{Type: provider.StreamEventText, Text: "Hello"},
{Type: provider.StreamEventText, Text: " world"},
{Type: provider.StreamEventDone, Response: &provider.Response{
Text: "Hello world",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 2,
TotalTokens: 12,
},
}},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
resp, err := reader.Collect()
if err != nil {
t.Fatalf("collect error: %v", err)
}
if resp.Text != "Hello world" {
t.Errorf("expected 'Hello world', got %q", resp.Text)
}
if resp.Usage == nil {
t.Fatal("expected usage")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
}
}
func TestStreamReader_Text(t *testing.T) {
events := []provider.StreamEvent{
{Type: provider.StreamEventText, Text: "result"},
{Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
text, err := reader.Text()
if err != nil {
t.Fatalf("text error: %v", err)
}
if text != "result" {
t.Errorf("expected 'result', got %q", text)
}
}
func TestStreamReader_EmptyStream(t *testing.T) {
// Stream that completes without a done event (no response)
mp := newMockStreamProvider([]provider.StreamEvent{
{Type: provider.StreamEventText, Text: "hi"},
})
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
_, err = reader.Collect()
if err == nil {
t.Fatal("expected error for stream without done event")
}
}
func TestStreamReader_StreamFuncError(t *testing.T) {
// Stream function returns error directly
mp := &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
return errors.New("stream init failed")
},
}
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error creating reader: %v", err)
}
defer reader.Close()
// The error should come through as an error event
_, err = reader.Collect()
if err == nil {
t.Fatal("expected error from stream function")
}
}

v2/tool.go Normal file

@@ -0,0 +1,193 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"reflect"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
)
// Tool defines a tool that the LLM can invoke.
type Tool struct {
// Name is the tool's unique identifier.
Name string
// Description tells the LLM what this tool does.
Description string
// Schema is the JSON Schema for the tool's parameters.
Schema map[string]any
// fn holds the implementation function (set via Define or DefineSimple).
fn reflect.Value
pTyp reflect.Type // nil for parameterless tools
// isMCP indicates this tool is provided by an MCP server.
isMCP bool
mcpServer *MCPServer
}
// Define creates a tool from a typed handler function.
// T must be a struct. Struct fields become the tool's parameters.
//
// Struct tags:
// - `json:"name"` — parameter name
// - `description:"..."` — parameter description
// - `enum:"a,b,c"` — enum constraint
//
// Pointer fields are optional; non-pointer fields are required.
//
// Example:
//
// type WeatherParams struct {
// City string `json:"city" description:"The city to query"`
// Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"`
// }
//
// llm.Define[WeatherParams]("get_weather", "Get weather for a city",
// func(ctx context.Context, p WeatherParams) (string, error) {
// return fmt.Sprintf("72F in %s", p.City), nil
// },
// )
func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool {
var zero T
return Tool{
Name: name,
Description: description,
Schema: schema.FromStruct(zero),
fn: reflect.ValueOf(fn),
pTyp: reflect.TypeOf(zero),
}
}
// DefineSimple creates a parameterless tool.
//
// Example:
//
// llm.DefineSimple("get_time", "Get the current time",
// func(ctx context.Context) (string, error) {
// return time.Now().Format(time.RFC3339), nil
// },
// )
func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool {
return Tool{
Name: name,
Description: description,
Schema: map[string]any{"type": "object", "properties": map[string]any{}},
fn: reflect.ValueOf(fn),
}
}
// Execute runs the tool with the given JSON arguments string.
func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) {
if t.isMCP {
var args map[string]any
if argsJSON != "" && argsJSON != "{}" {
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
return "", fmt.Errorf("invalid MCP tool arguments: %w", err)
}
}
return t.mcpServer.CallTool(ctx, t.Name, args)
}
// Parameterless tool
if t.pTyp == nil {
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx)})
if !out[1].IsNil() {
return "", out[1].Interface().(error)
}
return out[0].String(), nil
}
// Typed tool: unmarshal JSON into the struct, call the function
p := reflect.New(t.pTyp)
if argsJSON != "" && argsJSON != "{}" {
if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil {
return "", fmt.Errorf("invalid tool arguments: %w", err)
}
}
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
if !out[1].IsNil() {
return "", out[1].Interface().(error)
}
return out[0].String(), nil
}
// ToolBox is a collection of tools available for use by an LLM.
type ToolBox struct {
tools map[string]Tool
mcpServers []*MCPServer
}
// NewToolBox creates a new ToolBox from the given tools.
func NewToolBox(tools ...Tool) *ToolBox {
tb := &ToolBox{tools: make(map[string]Tool)}
for _, t := range tools {
tb.tools[t.Name] = t
}
return tb
}
// Add adds tools to the toolbox and returns it for chaining.
func (tb *ToolBox) Add(tools ...Tool) *ToolBox {
if tb.tools == nil {
tb.tools = make(map[string]Tool)
}
for _, t := range tools {
tb.tools[t.Name] = t
}
return tb
}
// AddMCP adds an MCP server's tools to the toolbox. The server must be connected.
func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox {
if tb.tools == nil {
tb.tools = make(map[string]Tool)
}
tb.mcpServers = append(tb.mcpServers, server)
for _, tool := range server.ListTools() {
tb.tools[tool.Name] = tool
}
return tb
}
// AllTools returns all tools (local + MCP) as a slice.
func (tb *ToolBox) AllTools() []Tool {
if tb == nil {
return nil
}
tools := make([]Tool, 0, len(tb.tools))
for _, t := range tb.tools {
tools = append(tools, t)
}
return tools
}
// Execute executes a tool call by name.
func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) {
if tb == nil {
return "", ErrNoToolsConfigured
}
tool, ok := tb.tools[call.Name]
if !ok {
return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name)
}
return tool.Execute(ctx, call.Arguments)
}
// ExecuteAll executes all tool calls and returns tool result messages.
func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) {
var results []Message
for _, call := range calls {
result, err := tb.Execute(ctx, call)
text := result
if err != nil {
text = "Error: " + err.Error()
}
results = append(results, ToolResultMessage(call.ID, text))
}
return results, nil
}

v2/tool_test.go (new file, 139 lines)
package llm
import (
"context"
"encoding/json"
"testing"
)
type calcParams struct {
A float64 `json:"a" description:"First number"`
B float64 `json:"b" description:"Second number"`
Op string `json:"op" description:"Operation" enum:"add,sub,mul,div"`
}
func TestDefine(t *testing.T) {
tool := Define[calcParams]("calc", "Calculator",
func(ctx context.Context, p calcParams) (string, error) {
var result float64
switch p.Op {
case "add":
result = p.A + p.B
case "sub":
result = p.A - p.B
case "mul":
result = p.A * p.B
case "div":
result = p.A / p.B
}
b, err := json.Marshal(result)
return string(b), err
},
)
if tool.Name != "calc" {
t.Errorf("expected name 'calc', got %q", tool.Name)
}
if tool.Description != "Calculator" {
t.Errorf("expected description 'Calculator', got %q", tool.Description)
}
if tool.Schema["type"] != "object" {
t.Errorf("expected schema type=object, got %v", tool.Schema["type"])
}
// Test execution
result, err := tool.Execute(context.Background(), `{"a": 10, "b": 3, "op": "add"}`)
if err != nil {
t.Fatalf("execute failed: %v", err)
}
if result != "13" {
t.Errorf("expected '13', got %q", result)
}
}
func TestDefineSimple(t *testing.T) {
tool := DefineSimple("hello", "Say hello",
func(ctx context.Context) (string, error) {
return "Hello, world!", nil
},
)
result, err := tool.Execute(context.Background(), "")
if err != nil {
t.Fatalf("execute failed: %v", err)
}
if result != "Hello, world!" {
t.Errorf("expected 'Hello, world!', got %q", result)
}
}
func TestToolBox(t *testing.T) {
tool1 := DefineSimple("tool1", "Tool 1", func(ctx context.Context) (string, error) {
return "result1", nil
})
tool2 := DefineSimple("tool2", "Tool 2", func(ctx context.Context) (string, error) {
return "result2", nil
})
tb := NewToolBox(tool1, tool2)
tools := tb.AllTools()
if len(tools) != 2 {
t.Errorf("expected 2 tools, got %d", len(tools))
}
result, err := tb.Execute(context.Background(), ToolCall{ID: "1", Name: "tool1"})
if err != nil {
t.Fatalf("execute failed: %v", err)
}
if result != "result1" {
t.Errorf("expected 'result1', got %q", result)
}
// Test not found
_, err = tb.Execute(context.Background(), ToolCall{ID: "x", Name: "nonexistent"})
if err == nil {
t.Error("expected error for nonexistent tool")
}
}
func TestToolBoxExecuteAll(t *testing.T) {
tb := NewToolBox(
DefineSimple("t1", "T1", func(ctx context.Context) (string, error) {
return "r1", nil
}),
DefineSimple("t2", "T2", func(ctx context.Context) (string, error) {
return "r2", nil
}),
)
calls := []ToolCall{
{ID: "c1", Name: "t1"},
{ID: "c2", Name: "t2"},
}
msgs, err := tb.ExecuteAll(context.Background(), calls)
if err != nil {
t.Fatalf("execute all failed: %v", err)
}
if len(msgs) != 2 {
t.Fatalf("expected 2 messages, got %d", len(msgs))
}
if msgs[0].Role != RoleTool {
t.Errorf("expected role=tool, got %v", msgs[0].Role)
}
if msgs[0].ToolCallID != "c1" {
t.Errorf("expected toolCallID=c1, got %v", msgs[0].ToolCallID)
}
if msgs[0].Content.Text != "r1" {
t.Errorf("expected content=r1, got %v", msgs[0].Content.Text)
}
}

v2/tools/browser.go (new file, 59 lines)
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// BrowserParams defines parameters for the browser tool.
type BrowserParams struct {
URL string `json:"url" description:"The URL to fetch and extract text from"`
}
// Browser creates a simple web page fetcher tool.
// It fetches a URL and returns the raw response body (capped at 1MB);
// HTML is not stripped or rendered.
//
// For a full headless browser, consider using an MCP server like Playwright MCP.
//
// Example:
//
// tools := llm.NewToolBox(tools.Browser())
func Browser() llm.Tool {
return llm.Define[BrowserParams]("browser", "Fetch a web page and return its text content",
func(ctx context.Context, p BrowserParams) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.URL, nil)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("User-Agent", "go-llm/2.0 (Web Fetcher)")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("fetching URL: %w", err)
}
defer resp.Body.Close()
// Limit to 1MB
limited := io.LimitReader(resp.Body, 1<<20)
body, err := io.ReadAll(limited)
if err != nil {
return "", fmt.Errorf("reading body: %w", err)
}
result := map[string]any{
"url": p.URL,
"status": resp.StatusCode,
"content_type": resp.Header.Get("Content-Type"),
"body": string(body),
}
out, _ := json.MarshalIndent(result, "", " ")
return string(out), nil
},
)
}

v2/tools/exec.go (new file, 101 lines)
package tools
import (
"context"
"fmt"
"os/exec"
"runtime"
"strings"
"time"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// ExecParams defines parameters for the exec tool.
type ExecParams struct {
Command string `json:"command" description:"The shell command to execute"`
}
// ExecOption configures the exec tool.
type ExecOption func(*execConfig)
type execConfig struct {
allowedCommands []string
workDir string
timeout time.Duration
}
// WithAllowedCommands restricts which commands can be executed.
// If empty, all commands are allowed. Only the first whitespace-separated
// token is checked, and the command still runs through a shell, so treat
// this as a convenience guard rather than a security boundary.
func WithAllowedCommands(cmds []string) ExecOption {
return func(c *execConfig) { c.allowedCommands = cmds }
}
// WithWorkDir sets the working directory for command execution.
func WithWorkDir(dir string) ExecOption {
return func(c *execConfig) { c.workDir = dir }
}
// WithExecTimeout sets the maximum execution time.
func WithExecTimeout(d time.Duration) ExecOption {
return func(c *execConfig) { c.timeout = d }
}
// Exec creates a shell command execution tool.
//
// Example:
//
// tools := llm.NewToolBox(
// tools.Exec(tools.WithAllowedCommands([]string{"ls", "cat", "grep"})),
// )
func Exec(opts ...ExecOption) llm.Tool {
cfg := &execConfig{
timeout: 30 * time.Second,
}
for _, opt := range opts {
opt(cfg)
}
return llm.Define[ExecParams]("exec", "Execute a shell command and return its output",
func(ctx context.Context, p ExecParams) (string, error) {
// Check allowed commands
if len(cfg.allowedCommands) > 0 {
parts := strings.Fields(p.Command)
if len(parts) == 0 {
return "", fmt.Errorf("empty command")
}
allowed := false
for _, cmd := range cfg.allowedCommands {
if parts[0] == cmd {
allowed = true
break
}
}
if !allowed {
return "", fmt.Errorf("command %q is not in the allowed list", parts[0])
}
}
ctx, cancel := context.WithTimeout(ctx, cfg.timeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/C", p.Command)
} else {
cmd = exec.CommandContext(ctx, "sh", "-c", p.Command)
}
if cfg.workDir != "" {
cmd.Dir = cfg.workDir
}
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Sprintf("Error: %s\nOutput: %s", err.Error(), string(output)), nil
}
return string(output), nil
},
)
}
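The allowlist gate above splits the command on whitespace and compares only the first token against the configured list. The check in isolation (allowedCmd is a hypothetical helper extracted for illustration):

```go
package main

import (
	"fmt"
	"strings"
)

// allowedCmd reproduces the gate: split on whitespace, compare only the
// first token against the allowlist.
func allowedCmd(command string, allowlist []string) bool {
	parts := strings.Fields(command)
	if len(parts) == 0 {
		return false
	}
	for _, c := range allowlist {
		if parts[0] == c {
			return true
		}
	}
	return false
}

func main() {
	list := []string{"ls", "cat", "grep"}
	fmt.Println(allowedCmd("ls -la /tmp", list)) // true
	fmt.Println(allowedCmd("rm -rf /", list))    // false
}
```

Because the real tool executes via `sh -c` (or `cmd /C`), a chained command such as `ls ; rm -rf /` still passes this first-token check.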

v2/tools/http.go (new file, 75 lines)
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// HTTPParams defines parameters for the HTTP request tool.
type HTTPParams struct {
Method string `json:"method" description:"HTTP method" enum:"GET,POST,PUT,DELETE,PATCH,HEAD"`
URL string `json:"url" description:"Request URL"`
Headers map[string]string `json:"headers,omitempty" description:"Request headers"`
Body *string `json:"body,omitempty" description:"Request body"`
}
// HTTP creates an HTTP request tool.
//
// Example:
//
// tools := llm.NewToolBox(tools.HTTP())
func HTTP() llm.Tool {
return llm.Define[HTTPParams]("http_request", "Make an HTTP request and return the response",
func(ctx context.Context, p HTTPParams) (string, error) {
var bodyReader io.Reader
if p.Body != nil {
bodyReader = bytes.NewBufferString(*p.Body)
}
req, err := http.NewRequestWithContext(ctx, p.Method, p.URL, bodyReader)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
for k, v := range p.Headers {
req.Header.Set(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
// Limit to 1MB
limited := io.LimitReader(resp.Body, 1<<20)
body, err := io.ReadAll(limited)
if err != nil {
return "", fmt.Errorf("reading response: %w", err)
}
headers := map[string]string{}
for k, v := range resp.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
result := map[string]any{
"status": resp.StatusCode,
"status_text": resp.Status,
"headers": headers,
"body": string(body),
}
out, _ := json.MarshalIndent(result, "", " ")
return string(out), nil
},
)
}

v2/tools/readfile.go (new file, 81 lines)
package tools
import (
"bufio"
"context"
"fmt"
"os"
"strings"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// ReadFileParams defines parameters for the read file tool.
type ReadFileParams struct {
Path string `json:"path" description:"File path to read"`
Start *int `json:"start,omitempty" description:"Starting line number (1-based, inclusive)"`
End *int `json:"end,omitempty" description:"Ending line number (1-based, inclusive)"`
}
// ReadFile creates a file reading tool.
//
// Example:
//
// tools := llm.NewToolBox(tools.ReadFile())
func ReadFile() llm.Tool {
return llm.Define[ReadFileParams]("read_file", "Read the contents of a file",
func(ctx context.Context, p ReadFileParams) (string, error) {
f, err := os.Open(p.Path)
if err != nil {
return "", fmt.Errorf("opening file: %w", err)
}
defer f.Close()
// If no line range specified, read the whole file (limited to 1MB)
if p.Start == nil && p.End == nil {
info, err := f.Stat()
if err != nil {
return "", fmt.Errorf("stat file: %w", err)
}
if info.Size() > 1<<20 {
return "", fmt.Errorf("file too large (%d bytes), use start/end to read a range", info.Size())
}
data, err := os.ReadFile(p.Path)
if err != nil {
return "", fmt.Errorf("reading file: %w", err)
}
return string(data), nil
}
// Read specific line range
start := 1
end := -1
if p.Start != nil {
start = *p.Start
}
if p.End != nil {
end = *p.End
}
var lines []string
scanner := bufio.NewScanner(f)
lineNum := 0
for scanner.Scan() {
lineNum++
if lineNum < start {
continue
}
if end > 0 && lineNum > end {
break
}
lines = append(lines, fmt.Sprintf("%d: %s", lineNum, scanner.Text()))
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("scanning file: %w", err)
}
return strings.Join(lines, "\n"), nil
},
)
}
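The line-range branch is a 1-based, inclusive scan: lines before start are skipped, and the loop stops once end is passed. The same loop extracted into a standalone sketch (rangeLines is an illustrative name):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// rangeLines reproduces the scan: 1-based line numbers, start and end
// both inclusive, end <= 0 meaning "read to EOF".
func rangeLines(text string, start, end int) []string {
	var lines []string
	sc := bufio.NewScanner(strings.NewReader(text))
	n := 0
	for sc.Scan() {
		n++
		if n < start {
			continue
		}
		if end > 0 && n > end {
			break
		}
		lines = append(lines, fmt.Sprintf("%d: %s", n, sc.Text()))
	}
	return lines
}

func main() {
	fmt.Println(rangeLines("a\nb\nc\nd", 2, 3)) // [2: b 3: c]
}
```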

v2/tools/websearch.go (new file, 101 lines)
// Package tools provides ready-to-use tool implementations for common agent patterns.
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// WebSearchParams defines parameters for the web search tool.
type WebSearchParams struct {
Query string `json:"query" description:"The search query"`
Count *int `json:"count,omitempty" description:"Number of results to return (default 5, max 20)"`
}
// WebSearch creates a web search tool using the Brave Search API.
//
// Get a free API key at https://brave.com/search/api/
//
// Example:
//
// tools := llm.NewToolBox(tools.WebSearch("your-brave-api-key"))
func WebSearch(apiKey string) llm.Tool {
return llm.Define[WebSearchParams]("web_search", "Search the web for information using Brave Search",
func(ctx context.Context, p WebSearchParams) (string, error) {
count := 5
if p.Count != nil && *p.Count > 0 {
count = *p.Count
if count > 20 {
count = 20
}
}
u := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(p.Query), count)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("search request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("search API returned %d: %s", resp.StatusCode, string(body))
}
// Parse and simplify the response
var raw map[string]any
if err := json.Unmarshal(body, &raw); err != nil {
return string(body), nil
}
type result struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
}
var results []result
if web, ok := raw["web"].(map[string]any); ok {
if items, ok := web["results"].([]any); ok {
for _, item := range items {
if m, ok := item.(map[string]any); ok {
r := result{}
if t, ok := m["title"].(string); ok {
r.Title = t
}
if u, ok := m["url"].(string); ok {
r.URL = u
}
if d, ok := m["description"].(string); ok {
r.Description = d
}
results = append(results, r)
}
}
}
}
out, _ := json.MarshalIndent(results, "", " ")
return string(out), nil
},
)
}

v2/tools/writefile.go (new file, 31 lines)
package tools
import (
"context"
"fmt"
"os"
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)
// WriteFileParams defines parameters for the write file tool.
type WriteFileParams struct {
Path string `json:"path" description:"File path to write"`
Content string `json:"content" description:"Content to write to the file"`
}
// WriteFile creates a file writing tool.
//
// Example:
//
// tools := llm.NewToolBox(tools.WriteFile())
func WriteFile() llm.Tool {
return llm.Define[WriteFileParams]("write_file", "Write content to a file (creates or overwrites)",
func(ctx context.Context, p WriteFileParams) (string, error) {
if err := os.WriteFile(p.Path, []byte(p.Content), 0644); err != nil {
return "", fmt.Errorf("writing file: %w", err)
}
return fmt.Sprintf("Successfully wrote %d bytes to %s", len(p.Content), p.Path), nil
},
)
}

v2/transcriber.go (new file, 100 lines)
package llm
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// Transcriber abstracts a speech-to-text model implementation.
type Transcriber = provider.Transcriber
// TranscriptionResponseFormat controls the output format requested from a transcriber.
type TranscriptionResponseFormat = provider.TranscriptionResponseFormat
const (
TranscriptionResponseFormatJSON = provider.TranscriptionResponseFormatJSON
TranscriptionResponseFormatVerboseJSON = provider.TranscriptionResponseFormatVerboseJSON
TranscriptionResponseFormatText = provider.TranscriptionResponseFormatText
TranscriptionResponseFormatSRT = provider.TranscriptionResponseFormatSRT
TranscriptionResponseFormatVTT = provider.TranscriptionResponseFormatVTT
)
// TranscriptionTimestampGranularity defines the requested timestamp detail.
type TranscriptionTimestampGranularity = provider.TranscriptionTimestampGranularity
const (
TranscriptionTimestampGranularityWord = provider.TranscriptionTimestampGranularityWord
TranscriptionTimestampGranularitySegment = provider.TranscriptionTimestampGranularitySegment
)
// TranscriptionOptions configures transcription behavior.
type TranscriptionOptions = provider.TranscriptionOptions
// Transcription captures a normalized transcription result.
type Transcription = provider.Transcription
// TranscriptionSegment provides a coarse time-sliced transcription segment.
type TranscriptionSegment = provider.TranscriptionSegment
// TranscriptionWord provides a word-level timestamp.
type TranscriptionWord = provider.TranscriptionWord
// TranscriptionTokenLogprob captures token-level log probability details.
type TranscriptionTokenLogprob = provider.TranscriptionTokenLogprob
// TranscriptionUsage captures token or duration usage details.
type TranscriptionUsage = provider.TranscriptionUsage
// TranscribeFile converts an audio file to WAV (via ffmpeg) and transcribes it.
func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) {
if transcriber == nil {
return Transcription{}, fmt.Errorf("transcriber is nil")
}
wav, err := audioFileToWav(ctx, filename)
if err != nil {
return Transcription{}, err
}
return transcriber.Transcribe(ctx, wav, opts)
}
func audioFileToWav(ctx context.Context, filename string) ([]byte, error) {
if filename == "" {
return nil, fmt.Errorf("filename is empty")
}
if strings.EqualFold(filepath.Ext(filename), ".wav") {
data, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("read wav file: %w", err)
}
return data, nil
}
tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav")
if err != nil {
return nil, fmt.Errorf("create temp wav file: %w", err)
}
tempPath := tempFile.Name()
_ = tempFile.Close()
defer os.Remove(tempPath)
cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath)
if output, err := cmd.CombinedOutput(); err != nil {
return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output)))
}
data, err := os.ReadFile(tempPath)
if err != nil {
return nil, fmt.Errorf("read converted wav file: %w", err)
}
return data, nil
}