From 34119e5a000e3d6ca078bbf2a0ce503010eeae8d Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Fri, 24 Apr 2026 13:34:39 +0000 Subject: [PATCH] feat: add DeepSeek, Moonshot, xAI, Groq, Ollama; drop v1; migrate TUI to v2 Five OpenAI-compatible providers join the library as first-class constructors (llm.DeepSeek, llm.Moonshot, llm.XAI, llm.Groq, llm.Ollama). Their wire-level implementation is shared via a new v2/openaicompat package which is the extracted guts of the old v2/openai provider; each provider supplies its own Rules value to declare per-model constraints (e.g., DeepSeek Reasoner rejects tools and temperature, Moonshot/xAI accept images only on *-vision* models, Groq rejects audio input). v2/openai itself becomes a thin wrapper that sets RestrictTemperature for o-series and gpt-5 models. A new provider registry (v2/registry.go) exposes llm.Providers() and drives the TUI's provider picker so adding a provider in future is a single-file change. The TUI at cmd/llm was migrated from v1 to v2 and moved to v2/cmd/llm. With nothing else depending on v1, the v1 code at the repo root (all .go files, schema/, internal/, provider/, root go.mod/go.sum) is deleted. 
Co-Authored-By: Claude Opus 4.7 --- CLAUDE.md | 103 ++--- anthropic.go | 225 ----------- cmd/llm/.env.example | 11 - cmd/llm/commands.go | 182 --------- cmd/llm/tools.go | 105 ------ context.go | 120 ------ error.go | 21 -- function.go | 136 ------- functions.go | 35 -- go.mod | 67 ---- go.sum | 145 -------- google.go | 165 -------- internal/imageutil/compress.go | 114 ------ llm.go | 30 -- mcp.go | 238 ------------ message.go | 115 ------ openai.go | 322 ---------------- openai_transcriber.go | 219 ----------- parse.go | 50 --- provider/anthropic/anthropic.go | 11 - provider/google/google.go | 11 - provider/openai/openai.go | 11 - request.go | 51 --- response.go | 52 --- schema/GetType.go | 142 ------- schema/array.go | 77 ---- schema/basic.go | 165 -------- schema/enum.go | 61 --- schema/object.go | 169 --------- schema/raw.go | 134 ------- schema/type.go | 23 -- toolbox.go | 174 --------- transcriber.go | 145 -------- v2/cmd/llm/.env.example | 27 ++ v2/cmd/llm/commands.go | 136 +++++++ {cmd => v2/cmd}/llm/main.go | 0 {cmd => v2/cmd}/llm/model.go | 160 +++----- {cmd => v2/cmd}/llm/styles.go | 0 v2/cmd/llm/tools.go | 114 ++++++ {cmd => v2/cmd}/llm/update.go | 138 +++---- {cmd => v2/cmd}/llm/view.go | 41 +- v2/constructors.go | 71 ++++ v2/deepseek/deepseek.go | 36 ++ v2/deepseek/deepseek_test.go | 49 +++ v2/go.mod | 29 +- v2/go.sum | 54 ++- v2/groq/groq.go | 33 ++ v2/groq/groq_test.go | 33 ++ v2/moonshot/moonshot.go | 30 ++ v2/moonshot/moonshot_test.go | 33 ++ v2/ollama/ollama.go | 25 ++ v2/ollama/ollama_test.go | 13 + v2/openai/openai.go | 442 ++-------------------- v2/openaicompat/openaicompat.go | 537 +++++++++++++++++++++++++++ v2/openaicompat/openaicompat_test.go | 313 ++++++++++++++++ v2/registry.go | 158 ++++++++ v2/xai/xai.go | 29 ++ v2/xai/xai_test.go | 33 ++ 58 files changed, 1921 insertions(+), 4242 deletions(-) delete mode 100644 anthropic.go delete mode 100644 cmd/llm/.env.example delete mode 100644 cmd/llm/commands.go delete mode 100644 
cmd/llm/tools.go delete mode 100644 context.go delete mode 100644 error.go delete mode 100644 function.go delete mode 100644 functions.go delete mode 100644 go.mod delete mode 100644 go.sum delete mode 100644 google.go delete mode 100644 internal/imageutil/compress.go delete mode 100644 llm.go delete mode 100644 mcp.go delete mode 100644 message.go delete mode 100644 openai.go delete mode 100644 openai_transcriber.go delete mode 100644 parse.go delete mode 100644 provider/anthropic/anthropic.go delete mode 100644 provider/google/google.go delete mode 100644 provider/openai/openai.go delete mode 100644 request.go delete mode 100644 response.go delete mode 100644 schema/GetType.go delete mode 100644 schema/array.go delete mode 100644 schema/basic.go delete mode 100644 schema/enum.go delete mode 100644 schema/object.go delete mode 100644 schema/raw.go delete mode 100644 schema/type.go delete mode 100644 toolbox.go delete mode 100644 transcriber.go create mode 100644 v2/cmd/llm/.env.example create mode 100644 v2/cmd/llm/commands.go rename {cmd => v2/cmd}/llm/main.go (100%) rename {cmd => v2/cmd}/llm/model.go (55%) rename {cmd => v2/cmd}/llm/styles.go (100%) create mode 100644 v2/cmd/llm/tools.go rename {cmd => v2/cmd}/llm/update.go (68%) rename {cmd => v2/cmd}/llm/view.go (88%) create mode 100644 v2/deepseek/deepseek.go create mode 100644 v2/deepseek/deepseek_test.go create mode 100644 v2/groq/groq.go create mode 100644 v2/groq/groq_test.go create mode 100644 v2/moonshot/moonshot.go create mode 100644 v2/moonshot/moonshot_test.go create mode 100644 v2/ollama/ollama.go create mode 100644 v2/ollama/ollama_test.go create mode 100644 v2/openaicompat/openaicompat.go create mode 100644 v2/openaicompat/openaicompat_test.go create mode 100644 v2/registry.go create mode 100644 v2/xai/xai.go create mode 100644 v2/xai/xai_test.go diff --git a/CLAUDE.md b/CLAUDE.md index 3824b30..4c3fd66 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,88 +1,31 @@ # CLAUDE.md for go-llm -## Build and 
Test Commands -- Build project: `go build ./...` -- Run all tests: `go test ./...` -- Run specific test: `go test -v -run ./...` -- Tidy dependencies: `go mod tidy` +All Go code now lives under `v2/`. The module path is +`gitea.stevedudenhoeffer.com/steve/go-llm/v2`. There is no module at the +repository root anymore; the v1 code at the root was deleted after all +consumers migrated to v2. -## Code Style Guidelines -- **Indentation**: Use standard Go tabs for indentation. -- **Naming**: - - Use `camelCase` for internal/private variables and functions. - - Use `PascalCase` for exported types, functions, and struct fields. - - Interface names should be concise (e.g., `LLM`, `ChatCompletion`). -- **Error Handling**: - - Always check and handle errors immediately. - - Wrap errors with context using `fmt.Errorf("%w: ...", err)`. - - Use the project's internal `Error` struct in `error.go` when differentiating between error types is needed. -- **Project Structure**: - - `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`). - - Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`. - - Schema definitions for tool calling are in the `schema/` directory. - - `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers. -- **Imports**: Organize imports into groups: standard library, then third-party libraries. -- **Documentation**: Use standard Go doc comments for exported symbols. -- **README.md**: The README.md file should always be kept up to date with any significant changes to the project. +See `v2/CLAUDE.md` for build/test commands and per-package guidance. 
-## CLI Tool -- Build CLI: `go build ./cmd/llm` -- Run CLI: `./llm` (or `llm.exe` on Windows) -- Run without building: `go run ./cmd/llm` +## CLI -### CLI Features -- Interactive TUI for testing all go-llm features -- Support for OpenAI, Anthropic, and Google providers -- Image input (file path, URL, or base64) -- Tool/function calling with demo tools -- Temperature control and settings +The interactive TUI lives at `v2/cmd/llm`: -### Key Bindings -- `Enter` - Send message -- `Ctrl+I` - Add image -- `Ctrl+T` - Toggle tools panel -- `Ctrl+P` - Change provider -- `Ctrl+M` - Change model -- `Ctrl+S` - Settings -- `Ctrl+N` - New conversation -- `Esc` - Exit/Cancel - -## MCP (Model Context Protocol) Support - -The library supports connecting to MCP servers to use their tools. MCP servers can be connected via: -- **stdio**: Run a command as a subprocess -- **sse**: Connect to an SSE endpoint -- **http**: Connect to a streamable HTTP endpoint - -### Usage Example -```go -ctx := context.Background() - -// Create and connect to an MCP server -server := &llm.MCPServer{ - Name: "my-server", - Command: "my-mcp-server", - Args: []string{"--some-flag"}, -} -if err := server.Connect(ctx); err != nil { - log.Fatal(err) -} -defer server.Close() - -// Add the server to a toolbox -toolbox := llm.NewToolBox().WithMCPServer(server) - -// Use the toolbox in requests - MCP tools are automatically available -req := llm.Request{ - Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}}, - Toolbox: toolbox, -} +``` +cd v2 && go run ./cmd/llm ``` -### MCPServer Options -- `Name`: Friendly name for logging -- `Command`: Command to run (for stdio transport) -- `Args`: Command arguments -- `Env`: Additional environment variables -- `URL`: Endpoint URL (for sse/http transport) -- `Transport`: "stdio" (default), "sse", or "http" +It iterates `llm.Providers()` so every registered provider (OpenAI, Anthropic, +Google, DeepSeek, Moonshot, xAI, Groq, Ollama) appears in the picker 
+automatically. Status is derived from each provider's env var; Ollama shows as +"(local)" because it needs no key. + +### Key bindings +- `Enter` — Send message +- `Ctrl+I` — Add image +- `Ctrl+T` — Toggle tools panel +- `Ctrl+P` — Change provider +- `Ctrl+M` — Change model +- `Ctrl+S` — Settings +- `Ctrl+N` — New conversation +- `Esc` — Exit/Cancel diff --git a/anthropic.go b/anthropic.go deleted file mode 100644 index b5e0244..0000000 --- a/anthropic.go +++ /dev/null @@ -1,225 +0,0 @@ -package llm - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log" - "log/slog" - "net/http" - - "gitea.stevedudenhoeffer.com/steve/go-llm/internal/imageutil" - - anth "github.com/liushuangls/go-anthropic/v2" -) - -type anthropicImpl struct { - key string - model string -} - -var _ LLM = anthropicImpl{} - -func (a anthropicImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { - a.model = modelVersion - - // TODO: model verification? - return a, nil -} - -func deferClose(c io.Closer) { - err := c.Close() - if err != nil { - slog.Error("error closing", "error", err) - } -} - -func (a anthropicImpl) requestToAnthropicRequest(req Request) anth.MessagesRequest { - res := anth.MessagesRequest{ - Model: anth.Model(a.model), - MaxTokens: 1000, - } - - msgs := []anth.Message{} - - // we gotta convert messages into anthropic messages, however - // anthropic does not have a "system" message type, so we need to - // append it to the res.System field instead - - for _, msg := range req.Messages { - if msg.Role == RoleSystem { - if len(res.System) > 0 { - res.System += "\n" - } - res.System += msg.Text - } else { - role := anth.RoleUser - - if msg.Role == RoleAssistant { - role = anth.RoleAssistant - } - - m := anth.Message{ - Role: role, - Content: []anth.MessageContent{}, - } - - if msg.Text != "" { - m.Content = append(m.Content, anth.MessageContent{ - Type: anth.MessagesContentTypeText, - Text: &msg.Text, - }) - } - - for _, img := range msg.Images 
{ - // anthropic doesn't allow the assistant to send images, so we need to say it's from the user - if m.Role == anth.RoleAssistant { - m.Role = anth.RoleUser - } - - if img.Base64 != "" { - // Anthropic models expect images to be < 5MiB in size - raw, err := base64.StdEncoding.DecodeString(img.Base64) - - if err != nil { - continue - } - - // Check if image size exceeds 5MiB (5242880 bytes) - if len(raw) >= 5242880 { - - compressed, mime, err := imageutil.CompressImage(img.Base64, 5*1024*1024) - - // just replace the image with the compressed one - if err != nil { - continue - } - - img.Base64 = compressed - img.ContentType = mime - } - - m.Content = append(m.Content, anth.NewImageMessageContent( - anth.NewMessageContentSource( - anth.MessagesContentSourceTypeBase64, - img.ContentType, - img.Base64, - ))) - } else if img.Url != "" { - - // download the image - cl, err := http.NewRequest(http.MethodGet, img.Url, nil) - if err != nil { - log.Println("failed to create request", err) - continue - } - - resp, err := http.DefaultClient.Do(cl) - if err != nil { - log.Println("failed to download image", err) - continue - } - - defer deferClose(resp.Body) - - img.ContentType = resp.Header.Get("Content-Type") - - // read the image - b, err := io.ReadAll(resp.Body) - if err != nil { - log.Println("failed to read image", err) - continue - } - - // base64 encode the image - img.Base64 = string(b) - - m.Content = append(m.Content, anth.NewImageMessageContent( - anth.NewMessageContentSource( - anth.MessagesContentSourceTypeBase64, - img.ContentType, - img.Base64, - ))) - } - } - - // if this has the same role as the previous message, we can append it to the previous message - // as anthropic expects alternating assistant and user roles - if len(msgs) > 0 && msgs[len(msgs)-1].Role == role { - m2 := &msgs[len(msgs)-1] - - m2.Content = append(m2.Content, m.Content...) 
- } else { - msgs = append(msgs, m) - } - } - } - - for _, tool := range req.Toolbox.Functions() { - res.Tools = append(res.Tools, anth.ToolDefinition{ - Name: tool.Name, - Description: tool.Description, - InputSchema: tool.Parameters.AnthropicInputSchema(), - }) - } - - res.Messages = msgs - - if req.Temperature != nil { - var f = float32(*req.Temperature) - res.Temperature = &f - } - - log.Println("llm request to anthropic request", res) - - return res -} - -func (a anthropicImpl) responseToLLMResponse(in anth.MessagesResponse) Response { - choice := ResponseChoice{} - for _, msg := range in.Content { - - switch msg.Type { - case anth.MessagesContentTypeText: - if msg.Text != nil { - choice.Content += *msg.Text - } - - case anth.MessagesContentTypeToolUse: - if msg.MessageContentToolUse != nil { - b, e := json.Marshal(msg.MessageContentToolUse.Input) - if e != nil { - log.Println("failed to marshal input", e) - } else { - choice.Calls = append(choice.Calls, ToolCall{ - ID: msg.MessageContentToolUse.ID, - FunctionCall: FunctionCall{ - Name: msg.MessageContentToolUse.Name, - Arguments: string(b), - }, - }) - } - } - } - } - - log.Println("anthropic response to llm response", choice) - - return Response{ - Choices: []ResponseChoice{choice}, - } -} - -func (a anthropicImpl) ChatComplete(ctx context.Context, req Request) (Response, error) { - cl := anth.NewClient(a.key) - - res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req)) - - if err != nil { - return Response{}, fmt.Errorf("failed to chat complete: %w", err) - } - - return a.responseToLLMResponse(res), nil -} diff --git a/cmd/llm/.env.example b/cmd/llm/.env.example deleted file mode 100644 index 3ddaa80..0000000 --- a/cmd/llm/.env.example +++ /dev/null @@ -1,11 +0,0 @@ -# go-llm CLI Environment Variables -# Copy this file to .env and fill in your API keys - -# OpenAI API Key (https://platform.openai.com/api-keys) -OPENAI_API_KEY= - -# Anthropic API Key (https://console.anthropic.com/settings/keys) 
-ANTHROPIC_API_KEY= - -# Google AI API Key (https://aistudio.google.com/apikey) -GOOGLE_API_KEY= diff --git a/cmd/llm/commands.go b/cmd/llm/commands.go deleted file mode 100644 index e5e7941..0000000 --- a/cmd/llm/commands.go +++ /dev/null @@ -1,182 +0,0 @@ -package main - -import ( - "context" - "encoding/base64" - "fmt" - "net/http" - "os" - "strings" - - tea "github.com/charmbracelet/bubbletea" - - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// Message types for async operations - -// ChatResponseMsg contains the response from a chat completion -type ChatResponseMsg struct { - Response llm.Response - Err error -} - -// ToolExecutionMsg contains results from tool execution -type ToolExecutionMsg struct { - Results []llm.ToolCallResponse - Err error -} - -// ImageLoadedMsg contains a loaded image -type ImageLoadedMsg struct { - Image llm.Image - Err error -} - -// sendChatRequest sends a chat completion request -func sendChatRequest(chat llm.ChatCompletion, req llm.Request) tea.Cmd { - return func() tea.Msg { - resp, err := chat.ChatComplete(context.Background(), req) - return ChatResponseMsg{Response: resp, Err: err} - } -} - -// executeTools executes tool calls and returns results -func executeTools(toolbox llm.ToolBox, req llm.Request, resp llm.ResponseChoice) tea.Cmd { - return func() tea.Msg { - ctx := llm.NewContext(context.Background(), req, &resp, nil) - var results []llm.ToolCallResponse - - for _, call := range resp.Calls { - result, err := toolbox.Execute(ctx, call) - results = append(results, llm.ToolCallResponse{ - ID: call.ID, - Result: result, - Error: err, - }) - } - - return ToolExecutionMsg{Results: results, Err: nil} - } -} - -// loadImageFromPath loads an image from a file path -func loadImageFromPath(path string) tea.Cmd { - return func() tea.Msg { - // Clean up the path - path = strings.TrimSpace(path) - path = strings.Trim(path, "\"'") - - // Read the file - data, err := os.ReadFile(path) - if err != nil { - return 
ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)} - } - - // Detect content type - contentType := http.DetectContentType(data) - if !strings.HasPrefix(contentType, "image/") { - return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)} - } - - // Base64 encode - encoded := base64.StdEncoding.EncodeToString(data) - - return ImageLoadedMsg{ - Image: llm.Image{ - Base64: encoded, - ContentType: contentType, - }, - } - } -} - -// loadImageFromURL loads an image from a URL -func loadImageFromURL(url string) tea.Cmd { - return func() tea.Msg { - url = strings.TrimSpace(url) - - // For URL images, we can just use the URL directly - return ImageLoadedMsg{ - Image: llm.Image{ - Url: url, - }, - } - } -} - -// loadImageFromBase64 loads an image from base64 data -func loadImageFromBase64(data string) tea.Cmd { - return func() tea.Msg { - data = strings.TrimSpace(data) - - // Check if it's a data URL - if strings.HasPrefix(data, "data:") { - // Parse data URL: data:image/png;base64,.... 
- parts := strings.SplitN(data, ",", 2) - if len(parts) != 2 { - return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")} - } - - // Extract content type from first part - mediaType := strings.TrimPrefix(parts[0], "data:") - mediaType = strings.TrimSuffix(mediaType, ";base64") - - return ImageLoadedMsg{ - Image: llm.Image{ - Base64: parts[1], - ContentType: mediaType, - }, - } - } - - // Assume it's raw base64, try to detect content type - decoded, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)} - } - - contentType := http.DetectContentType(decoded) - if !strings.HasPrefix(contentType, "image/") { - return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)} - } - - return ImageLoadedMsg{ - Image: llm.Image{ - Base64: data, - ContentType: contentType, - }, - } - } -} - -// buildRequest builds a chat request from the current state -func buildRequest(m *Model, userText string) llm.Request { - // Create the user message with any pending images - userMsg := llm.Message{ - Role: llm.RoleUser, - Text: userText, - Images: m.pendingImages, - } - - req := llm.Request{ - Conversation: m.conversation, - Messages: []llm.Message{ - {Role: llm.RoleSystem, Text: m.systemPrompt}, - userMsg, - }, - Temperature: m.temperature, - } - - // Add toolbox if enabled - if m.toolsEnabled && len(m.toolbox.Functions()) > 0 { - req.Toolbox = m.toolbox.WithRequireTool(false) - } - - return req -} - -// buildFollowUpRequest builds a follow-up request after tool execution -func buildFollowUpRequest(m *Model, previousReq llm.Request, resp llm.ResponseChoice, toolResults []llm.ToolCallResponse) llm.Request { - return previousReq.NextRequest(resp, toolResults) -} diff --git a/cmd/llm/tools.go b/cmd/llm/tools.go deleted file mode 100644 index 58f8164..0000000 --- a/cmd/llm/tools.go +++ /dev/null @@ -1,105 +0,0 @@ -package main - -import ( - "fmt" - "math" - "strconv" - "strings" - 
"time" - - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// TimeParams is the parameter struct for the GetTime function -type TimeParams struct{} - -// GetTime returns the current time -func GetTime(_ *llm.Context, _ TimeParams) (any, error) { - return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil -} - -// CalcParams is the parameter struct for the Calculate function -type CalcParams struct { - A float64 `json:"a" description:"First number"` - B float64 `json:"b" description:"Second number"` - Op string `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"` -} - -// Calculate performs basic math operations -func Calculate(_ *llm.Context, params CalcParams) (any, error) { - switch strings.ToLower(params.Op) { - case "add", "+": - return params.A + params.B, nil - case "subtract", "sub", "-": - return params.A - params.B, nil - case "multiply", "mul", "*": - return params.A * params.B, nil - case "divide", "div", "/": - if params.B == 0 { - return nil, fmt.Errorf("division by zero") - } - return params.A / params.B, nil - case "power", "pow", "^": - return math.Pow(params.A, params.B), nil - case "sqrt": - if params.A < 0 { - return nil, fmt.Errorf("cannot take square root of negative number") - } - return math.Sqrt(params.A), nil - case "mod", "%": - return math.Mod(params.A, params.B), nil - default: - return nil, fmt.Errorf("unknown operation: %s", params.Op) - } -} - -// WeatherParams is the parameter struct for the GetWeather function -type WeatherParams struct { - Location string `json:"location" description:"City name or location"` -} - -// GetWeather returns mock weather data (for demo purposes) -func GetWeather(_ *llm.Context, params WeatherParams) (any, error) { - // This is a demo function - returns mock data - weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"} - temps := []int{65, 72, 58, 80, 45} - - // Use location string to deterministically pick weather - idx := 
len(params.Location) % len(weathers) - - return map[string]any{ - "location": params.Location, - "temperature": strconv.Itoa(temps[idx]) + "F", - "condition": weathers[idx], - "humidity": "45%", - "note": "This is mock data for demonstration purposes", - }, nil -} - -// RandomNumberParams is the parameter struct for the RandomNumber function -type RandomNumberParams struct { - Min int `json:"min" description:"Minimum value (inclusive)"` - Max int `json:"max" description:"Maximum value (inclusive)"` -} - -// RandomNumber generates a pseudo-random number (using current time nanoseconds) -func RandomNumber(_ *llm.Context, params RandomNumberParams) (any, error) { - if params.Min > params.Max { - return nil, fmt.Errorf("min cannot be greater than max") - } - // Simple pseudo-random using time - n := time.Now().UnixNano() - rangeSize := params.Max - params.Min + 1 - result := params.Min + int(n%int64(rangeSize)) - return result, nil -} - -// createDemoToolbox creates a toolbox with demo tools for testing -func createDemoToolbox() llm.ToolBox { - return llm.NewToolBox( - llm.NewFunction("get_time", "Get the current date and time", GetTime), - llm.NewFunction("calculate", "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", Calculate), - llm.NewFunction("get_weather", "Get weather information for a location (demo data)", GetWeather), - llm.NewFunction("random_number", "Generate a random number between min and max", RandomNumber), - ) -} diff --git a/context.go b/context.go deleted file mode 100644 index 4990bb0..0000000 --- a/context.go +++ /dev/null @@ -1,120 +0,0 @@ -package llm - -import ( - "context" - "time" -) - -type Context struct { - context.Context - request Request - response *ResponseChoice - toolcall *ToolCall - syntheticFields map[string]string -} - -func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request { - var res Request - - res.Toolbox = c.request.Toolbox - res.Temperature = c.request.Temperature - - 
res.Conversation = make([]Input, len(c.request.Conversation)) - copy(res.Conversation, c.request.Conversation) - - // now for every input message, convert those to an Input to add to the conversation - for _, msg := range c.request.Messages { - res.Conversation = append(res.Conversation, msg) - } - - // if there are tool calls, then we need to add those to the conversation - if c.response != nil { - res.Conversation = append(res.Conversation, *c.response) - } - - // if there are tool results, then we need to add those to the conversation - for _, result := range toolResults { - res.Conversation = append(res.Conversation, result) - } - - return res -} - -func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context { - return &Context{Context: ctx, request: request, response: response, toolcall: toolcall} -} - -func (c *Context) Request() Request { - return c.request -} - -func (c *Context) Response() *ResponseChoice { - return c.response -} - -func (c *Context) ToolCall() *ToolCall { - return c.toolcall -} - -func (c *Context) SyntheticFields() map[string]string { - if c.syntheticFields == nil { - c.syntheticFields = map[string]string{} - } - - return c.syntheticFields -} - -func (c *Context) WithContext(ctx context.Context) *Context { - return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithRequest(request Request) *Context { - return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithResponse(response *ResponseChoice) *Context { - return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithToolCall(toolcall *ToolCall) *Context { - return &Context{Context: c.Context, request: c.request, response: c.response, 
toolcall: toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context { - return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields} -} - -func (c *Context) Deadline() (deadline time.Time, ok bool) { - return c.Context.Deadline() -} - -func (c *Context) Done() <-chan struct{} { - return c.Context.Done() -} - -func (c *Context) Err() error { - return c.Context.Err() -} - -func (c *Context) Value(key any) any { - switch key { - case "request": - return c.request - - case "response": - return c.response - - case "toolcall": - return c.toolcall - - case "syntheticFields": - return c.syntheticFields - - } - return c.Context.Value(key) -} - -func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) { - ctx, cancel := context.WithTimeout(c.Context, timeout) - return c.WithContext(ctx), cancel -} diff --git a/error.go b/error.go deleted file mode 100644 index 8f432bb..0000000 --- a/error.go +++ /dev/null @@ -1,21 +0,0 @@ -package llm - -import "fmt" - -// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error. 
-type Error struct { - error - - Source error - Parameter error -} - -func newError(parent error, err error) Error { - e := fmt.Errorf("%w: %w", parent, err) - return Error{ - error: e, - - Source: parent, - Parameter: err, - } -} diff --git a/function.go b/function.go deleted file mode 100644 index 4ef5df0..0000000 --- a/function.go +++ /dev/null @@ -1,136 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "reflect" - "time" - - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" -) - -type Function struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Strict bool `json:"strict,omitempty"` - Parameters schema.Type `json:"parameters"` - - Forced bool `json:"forced,omitempty"` - - // Timeout is the maximum time to wait for the function to complete - Timeout time.Duration `json:"-"` - - // fn is the function to call, only set if this is constructed with NewFunction - fn reflect.Value - - paramType reflect.Type -} - -func (f Function) WithSyntheticField(name string, description string) Function { - if obj, o := f.Parameters.(schema.Object); o { - f.Parameters = obj.WithSyntheticField(name, description) - } - - return f -} - -func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function { - if obj, o := f.Parameters.(schema.Object); o { - for k, v := range fieldsAndDescriptions { - obj = obj.WithSyntheticField(k, v) - } - f.Parameters = obj - } - - return f -} - -func (f Function) WithDescription(description string) Function { - f.Description = description - return f -} - -func (f Function) Execute(ctx *Context, input string) (any, error) { - if !f.fn.IsValid() { - return "", fmt.Errorf("function %s is not implemented", f.Name) - } - - slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType) - // first, we need to parse the input into the struct - p := reflect.New(f.paramType) - fmt.Println("Function.Execute", f.Name, "input:", input) - - var vals 
map[string]any - err := json.Unmarshal([]byte(input), &vals) - - var syntheticFields map[string]string - - // first eat up any synthetic fields - if obj, o := f.Parameters.(schema.Object); o { - for k := range obj.SyntheticFields() { - key := schema.SyntheticFieldPrefix + k - if val, ok := vals[key]; ok { - if syntheticFields == nil { - syntheticFields = map[string]string{} - } - - syntheticFields[k] = fmt.Sprint(val) - delete(vals, key) - } - } - } - - // now for any remaining fields, re-marshal them into json and then unmarshal into the struct - b, err := json.Marshal(vals) - if err != nil { - return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input) - } - - // now we can unmarshal the input into the struct - err = json.Unmarshal(b, p.Interface()) - if err != nil { - return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input) - } - - // now we can call the function - exec := func(ctx *Context) (any, error) { - out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) - - if len(out) != 2 { - return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out)) - } - - if out[1].IsNil() { - return out[0].Interface(), nil - } - - return "", out[1].Interface().(error) - } - - var cancel context.CancelFunc - if f.Timeout > 0 { - ctx, cancel = ctx.WithTimeout(f.Timeout) - defer cancel() - } - - return exec(ctx) -} - -type FunctionCall struct { - Name string `json:"name,omitempty"` - Arguments string `json:"arguments,omitempty"` -} - -func (fc *FunctionCall) toRaw() map[string]any { - res := map[string]interface{}{ - "name": fc.Name, - } - - if fc.Arguments != "" { - res["arguments"] = fc.Arguments - } - - return res -} diff --git a/functions.go b/functions.go deleted file mode 100644 index dfdcebd..0000000 --- a/functions.go +++ /dev/null @@ -1,35 +0,0 @@ -package llm - -import ( - "reflect" - - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" -) - -// Parse takes a function pointer and returns a function 
object. -// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains -// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to -// those types. -// The struct parameters can have the following tags: -// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for - -func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function { - var o T - - res := Function{ - Name: name, - Description: description, - Parameters: schema.GetType(o), - fn: reflect.ValueOf(fn), - paramType: reflect.TypeOf(o), - } - - if res.fn.Kind() != reflect.Func { - panic("fn must be a function") - } - if res.paramType.Kind() != reflect.Struct { - panic("function parameter must be a struct") - } - - return res -} diff --git a/go.mod b/go.mod deleted file mode 100644 index 8e10002..0000000 --- a/go.mod +++ /dev/null @@ -1,67 +0,0 @@ -module gitea.stevedudenhoeffer.com/steve/go-llm - -go 1.24.0 - -toolchain go1.24.2 - -require ( - github.com/charmbracelet/bubbles v0.21.0 - github.com/charmbracelet/bubbletea v1.3.10 - github.com/charmbracelet/lipgloss v1.1.0 - github.com/joho/godotenv v1.5.1 - github.com/liushuangls/go-anthropic/v2 v2.17.0 - github.com/modelcontextprotocol/go-sdk v1.2.0 - github.com/openai/openai-go v1.12.0 - golang.org/x/image v0.35.0 - google.golang.org/genai v1.43.0 -) - -require ( - cloud.google.com/go v0.123.0 // indirect - cloud.google.com/go/auth v0.18.1 // indirect - cloud.google.com/go/compute/metadata v0.9.0 // indirect - github.com/atotto/clipboard v0.1.4 // indirect - github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/x/ansi v0.10.1 // indirect - github.com/charmbracelet/x/cellbuf 
v0.0.13-0.20250311204145-2c3ea96c31dd // indirect - github.com/charmbracelet/x/term v0.2.1 // indirect - github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.3 // indirect - github.com/go-logr/stdr v1.2.2 // indirect - github.com/google/go-cmp v0.7.0 // indirect - github.com/google/jsonschema-go v0.3.0 // indirect - github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect - github.com/googleapis/gax-go/v2 v2.16.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect - github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect - github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.16.0 // indirect - github.com/rivo/uniseg v0.4.7 // indirect - github.com/tidwall/gjson v1.18.0 // indirect - github.com/tidwall/match v1.2.0 // indirect - github.com/tidwall/pretty v1.2.1 // indirect - github.com/tidwall/sjson v1.2.5 // indirect - github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect - go.opentelemetry.io/otel v1.39.0 // indirect - go.opentelemetry.io/otel/metric v1.39.0 // indirect - go.opentelemetry.io/otel/trace v1.39.0 // indirect - golang.org/x/crypto v0.47.0 // indirect - golang.org/x/net v0.49.0 // indirect - golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect - google.golang.org/grpc v1.78.0 
// indirect - google.golang.org/protobuf v1.36.11 // indirect -) diff --git a/go.sum b/go.sum deleted file mode 100644 index a7a8e88..0000000 --- a/go.sum +++ /dev/null @@ -1,145 +0,0 @@ -cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= -cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= -cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= -cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= -cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= -cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= -github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= -github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= -github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= -github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= -github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= -github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= 
-github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= -github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= -github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= -github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= -github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= -github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= 
-github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= -github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= -github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= -github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= -github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y= -github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c= -github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.20 
h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= -github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= -github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= -github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= -github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= -github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= -github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= -github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify 
v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= -github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= -github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= -github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= -github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= -github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= -go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= -go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= -go.opentelemetry.io/otel v1.39.0/go.mod 
h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= -go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= -go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= -go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= -go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= -go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= -go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= -go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= -go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= -golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= -golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= -golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys 
v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM= -google.golang.org/genai v1.43.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d h1:xXzuihhT3gL/ntduUZwHECzAn57E8dA6l8SOtYWdD8Q= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= -google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= -google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= -google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= -google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/google.go b/google.go deleted file mode 100644 index 6316f63..0000000 --- a/google.go +++ /dev/null @@ -1,165 +0,0 @@ -package llm - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - - "google.golang.org/genai" -) - -type googleImpl struct { - key string - model string -} - -var _ LLM = googleImpl{} - -func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { - g.model = modelVersion 
- - return g, nil -} - -func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) { - var contents []*genai.Content - var cfg genai.GenerateContentConfig - - for _, tool := range in.Toolbox.Functions() { - cfg.Tools = append(cfg.Tools, &genai.Tool{ - FunctionDeclarations: []*genai.FunctionDeclaration{ - { - Name: tool.Name, - Description: tool.Description, - Parameters: tool.Parameters.GoogleParameters(), - }, - }, - }) - } - - if in.Toolbox.RequiresTool() { - cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{ - Mode: genai.FunctionCallingConfigModeAny, - }} - } - - for _, c := range in.Messages { - var role genai.Role - switch c.Role { - case RoleAssistant, RoleSystem: - role = genai.RoleModel - case RoleUser: - role = genai.RoleUser - } - - var parts []*genai.Part - if c.Text != "" { - parts = append(parts, genai.NewPartFromText(c.Text)) - } - - for _, img := range c.Images { - if img.Url != "" { - // gemini does not support URLs, so we need to download the image and convert it to a blob - resp, err := http.Get(img.Url) - if err != nil { - panic(fmt.Sprintf("error downloading image: %v", err)) - } - defer resp.Body.Close() - - if resp.ContentLength > 20*1024*1024 { - panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength)) - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - panic(fmt.Sprintf("error reading image data: %v", err)) - } - - mimeType := http.DetectContentType(data) - switch mimeType { - case "image/jpeg", "image/png", "image/gif": - // MIME type is valid - default: - panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType)) - } - - parts = append(parts, genai.NewPartFromBytes(data, mimeType)) - } else { - b, e := base64.StdEncoding.DecodeString(img.Base64) - if e != nil { - panic(fmt.Sprintf("error decoding base64: %v", e)) - } - - parts = append(parts, genai.NewPartFromBytes(b, img.ContentType)) - } - } - - contents = append(contents, 
genai.NewContentFromParts(parts, role)) - } - - return contents, &cfg -} - -func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) { - res := Response{} - - for _, c := range in.Candidates { - var choice ResponseChoice - var set = false - if c.Content != nil { - for _, p := range c.Content.Parts { - if p.Text != "" { - set = true - choice.Content = p.Text - } else if p.FunctionCall != nil { - v := p.FunctionCall - b, e := json.Marshal(v.Args) - if e != nil { - return Response{}, fmt.Errorf("error marshalling args: %w", e) - } - - call := ToolCall{ - ID: v.Name, - FunctionCall: FunctionCall{ - Name: v.Name, - Arguments: string(b), - }, - } - - choice.Calls = append(choice.Calls, call) - set = true - } - } - } - - if set { - choice.Role = RoleAssistant - res.Choices = append(res.Choices, choice) - } - } - - return res, nil -} - -func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) { - cl, err := genai.NewClient(ctx, &genai.ClientConfig{ - APIKey: g.key, - Backend: genai.BackendGeminiAPI, - }) - - if err != nil { - return Response{}, fmt.Errorf("error creating genai client: %w", err) - } - - contents, cfg := g.requestToContents(req) - - resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg) - if err != nil { - return Response{}, fmt.Errorf("error generating content: %w", err) - } - - return g.responseToLLMResponse(resp) -} diff --git a/internal/imageutil/compress.go b/internal/imageutil/compress.go deleted file mode 100644 index cba4d25..0000000 --- a/internal/imageutil/compress.go +++ /dev/null @@ -1,114 +0,0 @@ -package imageutil - -import ( - "bytes" - "encoding/base64" - "fmt" - "image" - "image/gif" - "image/jpeg" - "net/http" - - "golang.org/x/image/draw" -) - -// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns -// a base-64-encoded version that is at most maxLength in size, or an error. 
-func CompressImage(b64 string, maxLength int) (string, string, error) { - raw, err := base64.StdEncoding.DecodeString(b64) - if err != nil { - return "", "", fmt.Errorf("base64 decode: %w", err) - } - - mime := http.DetectContentType(raw) - if len(raw) <= maxLength { - return b64, mime, nil // small enough already - } - - switch mime { - case "image/gif": - return compressGIF(raw, maxLength) - - default: // jpeg, png, webp, etc. -> treat as raster - return compressRaster(raw, maxLength) - } -} - -// ---------- Raster path (jpeg / png / single-frame gif) ---------- - -func compressRaster(src []byte, maxLength int) (string, string, error) { - img, _, err := image.Decode(bytes.NewReader(src)) - if err != nil { - return "", "", fmt.Errorf("decode raster: %w", err) - } - - quality := 95 - for { - var buf bytes.Buffer - if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil { - return "", "", fmt.Errorf("jpeg encode: %w", err) - } - if buf.Len() <= maxLength { - return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/jpeg", nil - } - - if quality > 20 { - quality -= 5 - continue - } - - // down-scale 80% - b := img.Bounds() - if b.Dx() < 100 || b.Dy() < 100 { - return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0) - } - dst := image.NewRGBA(image.Rect(0, 0, int(float64(b.Dx())*0.8), int(float64(b.Dy())*0.8))) - draw.ApproxBiLinear.Scale(dst, dst.Bounds(), img, b, draw.Over, nil) - img = dst - quality = 95 // restart ladder - } -} - -// ---------- Animated GIF path ---------- - -func compressGIF(src []byte, maxLength int) (string, string, error) { - g, err := gif.DecodeAll(bytes.NewReader(src)) - if err != nil { - return "", "", fmt.Errorf("gif decode: %w", err) - } - - for { - var buf bytes.Buffer - if err := gif.EncodeAll(&buf, g); err != nil { - return "", "", fmt.Errorf("gif encode: %w", err) - } - if buf.Len() <= maxLength { - return 
base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil - } - - // down-scale every frame by 80% - w, h := g.Config.Width, g.Config.Height - if w < 100 || h < 100 { - return "", "", fmt.Errorf("cannot compress animated GIF below 5 MiB without excessive quality loss") - } - - nw, nh := int(float64(w)*0.8), int(float64(h)*0.8) - for i, frm := range g.Image { - // convert paletted frame -> RGBA for scaling - rgba := image.NewRGBA(frm.Bounds()) - draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src) - - // scaled destination - dst := image.NewRGBA(image.Rect(0, 0, nw, nh)) - draw.ApproxBiLinear.Scale(dst, dst.Bounds(), rgba, rgba.Bounds(), draw.Over, nil) - - // quantize back to paletted using default encoder quantizer - paletted := image.NewPaletted(dst.Bounds(), nil) - draw.FloydSteinberg.Draw(paletted, paletted.Bounds(), dst, dst.Bounds().Min) - - g.Image[i] = paletted - } - g.Config.Width, g.Config.Height = nw, nh - // loop back and test size again ... - } -} diff --git a/llm.go b/llm.go deleted file mode 100644 index 3038430..0000000 --- a/llm.go +++ /dev/null @@ -1,30 +0,0 @@ -package llm - -import ( - "context" -) - -// ChatCompletion is the interface for chat completion. -type ChatCompletion interface { - ChatComplete(ctx context.Context, req Request) (Response, error) -} - -// LLM is the interface for language model providers. -type LLM interface { - ModelVersion(modelVersion string) (ChatCompletion, error) -} - -// OpenAI creates a new OpenAI LLM provider with the given API key. -func OpenAI(key string) LLM { - return openaiImpl{key: key} -} - -// Anthropic creates a new Anthropic LLM provider with the given API key. -func Anthropic(key string) LLM { - return anthropicImpl{key: key} -} - -// Google creates a new Google LLM provider with the given API key. 
-func Google(key string) LLM { - return googleImpl{key: key} -} diff --git a/mcp.go b/mcp.go deleted file mode 100644 index 342a6fb..0000000 --- a/mcp.go +++ /dev/null @@ -1,238 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "fmt" - "os" - "os/exec" - "sync" - - "github.com/modelcontextprotocol/go-sdk/mcp" - - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" -) - -// MCPServer represents a connection to an MCP server. -// It manages the lifecycle of the connection and provides access to the server's tools. -type MCPServer struct { - // Name is a friendly name for this server (used for logging/identification) - Name string - - // Command is the command to run the MCP server (for stdio transport) - Command string - - // Args are arguments to pass to the command - Args []string - - // Env are environment variables to set for the command (in addition to current environment) - Env []string - - // URL is the URL for SSE or HTTP transport (alternative to Command) - URL string - - // Transport specifies the transport type: "stdio" (default), "sse", or "http" - Transport string - - client *mcp.Client - session *mcp.ClientSession - tools map[string]*mcp.Tool // tool name -> tool definition - mu sync.RWMutex -} - -// Connect establishes a connection to the MCP server. -func (m *MCPServer) Connect(ctx context.Context) error { - m.mu.Lock() - defer m.mu.Unlock() - - if m.session != nil { - return nil // Already connected - } - - m.client = mcp.NewClient(&mcp.Implementation{ - Name: "go-llm", - Version: "1.0.0", - }, nil) - - var transport mcp.Transport - - switch m.Transport { - case "sse": - transport = &mcp.SSEClientTransport{ - Endpoint: m.URL, - } - case "http": - transport = &mcp.StreamableClientTransport{ - Endpoint: m.URL, - } - default: // "stdio" or empty - cmd := exec.Command(m.Command, m.Args...) - cmd.Env = append(os.Environ(), m.Env...) 
- transport = &mcp.CommandTransport{ - Command: cmd, - } - } - - session, err := m.client.Connect(ctx, transport, nil) - if err != nil { - return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err) - } - - m.session = session - - // Load tools - m.tools = make(map[string]*mcp.Tool) - for tool, err := range session.Tools(ctx, nil) { - if err != nil { - m.session.Close() - m.session = nil - return fmt.Errorf("failed to list tools from %s: %w", m.Name, err) - } - m.tools[tool.Name] = tool - } - - return nil -} - -// Close closes the connection to the MCP server. -func (m *MCPServer) Close() error { - m.mu.Lock() - defer m.mu.Unlock() - - if m.session == nil { - return nil - } - - err := m.session.Close() - m.session = nil - m.tools = nil - return err -} - -// IsConnected returns true if the server is connected. -func (m *MCPServer) IsConnected() bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.session != nil -} - -// Tools returns the list of tool names available from this server. -func (m *MCPServer) Tools() []string { - m.mu.RLock() - defer m.mu.RUnlock() - - var names []string - for name := range m.tools { - names = append(names, name) - } - return names -} - -// HasTool returns true if this server provides the named tool. -func (m *MCPServer) HasTool(name string) bool { - m.mu.RLock() - defer m.mu.RUnlock() - _, ok := m.tools[name] - return ok -} - -// CallTool calls a tool on the MCP server. 
-func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) { - m.mu.RLock() - session := m.session - m.mu.RUnlock() - - if session == nil { - return nil, fmt.Errorf("not connected to MCP server %s", m.Name) - } - - result, err := session.CallTool(ctx, &mcp.CallToolParams{ - Name: name, - Arguments: arguments, - }) - if err != nil { - return nil, err - } - - // Process the result content - if len(result.Content) == 0 { - return nil, nil - } - - // If there's a single text content, return it as a string - if len(result.Content) == 1 { - if textContent, ok := result.Content[0].(*mcp.TextContent); ok { - return textContent.Text, nil - } - } - - // For multiple contents or non-text, serialize to string - return contentToString(result.Content), nil -} - -// toFunction converts an MCP tool to a go-llm Function (for schema purposes only). -func (m *MCPServer) toFunction(tool *mcp.Tool) Function { - var inputSchema map[string]any - if tool.InputSchema != nil { - data, err := json.Marshal(tool.InputSchema) - if err == nil { - _ = json.Unmarshal(data, &inputSchema) - } - } - - if inputSchema == nil { - inputSchema = map[string]any{ - "type": "object", - "properties": map[string]any{}, - } - } - - return Function{ - Name: tool.Name, - Description: tool.Description, - Parameters: schema.NewRaw(inputSchema), - } -} - -// contentToString converts MCP content to a string representation. -func contentToString(content []mcp.Content) string { - var parts []string - for _, c := range content { - switch tc := c.(type) { - case *mcp.TextContent: - parts = append(parts, tc.Text) - default: - if data, err := json.Marshal(c); err == nil { - parts = append(parts, string(data)) - } - } - } - if len(parts) == 1 { - return parts[0] - } - data, _ := json.Marshal(parts) - return string(data) -} - -// WithMCPServer adds an MCP server to the toolbox. -// The server must already be connected. 
Tools from the server will be available -// for use, and tool calls will be routed to the appropriate server. -func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox { - if t.mcpServers == nil { - t.mcpServers = make(map[string]*MCPServer) - } - - server.mu.RLock() - defer server.mu.RUnlock() - - for name, tool := range server.tools { - // Add the function definition (for schema) - fn := server.toFunction(tool) - t.functions[name] = fn - - // Track which server owns this tool - t.mcpServers[name] = server - } - - return t -} diff --git a/message.go b/message.go deleted file mode 100644 index b1296ba..0000000 --- a/message.go +++ /dev/null @@ -1,115 +0,0 @@ -package llm - -// Role represents the role of a message in a conversation. -type Role string - -const ( - RoleSystem Role = "system" - RoleUser Role = "user" - RoleAssistant Role = "assistant" -) - -// Image represents an image that can be included in a message. -type Image struct { - Base64 string - ContentType string - Url string -} - -func (i Image) toRaw() map[string]any { - res := map[string]any{ - "base64": i.Base64, - "contenttype": i.ContentType, - "url": i.Url, - } - - return res -} - -func (i *Image) fromRaw(raw map[string]any) Image { - var res Image - - res.Base64 = raw["base64"].(string) - res.ContentType = raw["contenttype"].(string) - res.Url = raw["url"].(string) - - return res -} - -// Message represents a message in a conversation. 
-type Message struct { - Role Role - Name string - Text string - Images []Image -} - -func (m Message) toRaw() map[string]any { - res := map[string]any{ - "role": m.Role, - "name": m.Name, - "text": m.Text, - } - - images := make([]map[string]any, 0, len(m.Images)) - for _, img := range m.Images { - images = append(images, img.toRaw()) - } - - res["images"] = images - - return res -} - -func (m *Message) fromRaw(raw map[string]any) Message { - var res Message - - res.Role = Role(raw["role"].(string)) - res.Name = raw["name"].(string) - res.Text = raw["text"].(string) - - images := raw["images"].([]map[string]any) - for _, img := range images { - var i Image - - res.Images = append(res.Images, i.fromRaw(img)) - } - - return res -} - -// ToolCall represents a tool call made by an assistant. -type ToolCall struct { - ID string - FunctionCall FunctionCall -} - -func (t ToolCall) toRaw() map[string]any { - res := map[string]any{ - "id": t.ID, - } - - res["function"] = t.FunctionCall.toRaw() - - return res -} - -// ToolCallResponse represents the response to a tool call. 
-type ToolCallResponse struct { - ID string - Result any - Error error -} - -func (t ToolCallResponse) toRaw() map[string]any { - res := map[string]any{ - "id": t.ID, - "result": t.Result, - } - - if t.Error != nil { - res["error"] = t.Error.Error() - } - - return res -} diff --git a/openai.go b/openai.go deleted file mode 100644 index 5d71402..0000000 --- a/openai.go +++ /dev/null @@ -1,322 +0,0 @@ -package llm - -import ( - "context" - "fmt" - "strings" - - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/shared" -) - -type openaiImpl struct { - key string - model string - baseUrl string -} - -var _ LLM = openaiImpl{} - -func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams { - res := openai.ChatCompletionNewParams{ - Model: o.model, - } - - for _, i := range request.Conversation { - res.Messages = append(res.Messages, inputToChatCompletionMessages(i, o.model)...) - } - - for _, msg := range request.Messages { - res.Messages = append(res.Messages, messageToChatCompletionMessages(msg, o.model)...) 
- } - - for _, tool := range request.Toolbox.Functions() { - res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ - Type: "function", - Function: shared.FunctionDefinitionParam{ - Name: tool.Name, - Description: openai.String(tool.Description), - Strict: openai.Bool(tool.Strict), - Parameters: tool.Parameters.OpenAIParameters(), - }, - }) - } - - if request.Toolbox.RequiresTool() { - res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ - OfAuto: openai.String("required"), - } - } - - if request.Temperature != nil { - // these are known models that do not support custom temperatures - // all the o* models - // gpt-5* models - if !strings.HasPrefix(o.model, "o") && !strings.HasPrefix(o.model, "gpt-5") { - res.Temperature = openai.Float(*request.Temperature) - } - } - - return res -} - -func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response { - var res Response - - if response == nil { - return res - } - - if len(response.Choices) == 0 { - return res - } - - for _, choice := range response.Choices { - var toolCalls []ToolCall - for _, call := range choice.Message.ToolCalls { - toolCall := ToolCall{ - ID: call.ID, - FunctionCall: FunctionCall{ - Name: call.Function.Name, - Arguments: strings.TrimSpace(call.Function.Arguments), - }, - } - - toolCalls = append(toolCalls, toolCall) - - } - res.Choices = append(res.Choices, ResponseChoice{ - Content: choice.Message.Content, - Role: Role(choice.Message.Role), - Refusal: choice.Message.Refusal, - Calls: toolCalls, - }) - } - - return res -} - -func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) { - var opts = []option.RequestOption{ - option.WithAPIKey(o.key), - } - - if o.baseUrl != "" { - opts = append(opts, option.WithBaseURL(o.baseUrl)) - } - - cl := openai.NewClient(opts...) 
- - req := o.newRequestToOpenAIRequest(request) - - resp, err := cl.Chat.Completions.New(ctx, req) - - if err != nil { - return Response{}, fmt.Errorf("unhandled openai error: %w", err) - } - - return o.responseToLLMResponse(resp), nil -} - -func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { - return openaiImpl{ - key: o.key, - model: modelVersion, - baseUrl: o.baseUrl, - }, nil -} - -// inputToChatCompletionMessages converts an Input to OpenAI chat completion messages. -func inputToChatCompletionMessages(input Input, model string) []openai.ChatCompletionMessageParamUnion { - switch v := input.(type) { - case Message: - return messageToChatCompletionMessages(v, model) - case ToolCall: - return toolCallToChatCompletionMessages(v) - case ToolCallResponse: - return toolCallResponseToChatCompletionMessages(v) - case ResponseChoice: - return responseChoiceToChatCompletionMessages(v) - default: - return nil - } -} - -func messageToChatCompletionMessages(m Message, model string) []openai.ChatCompletionMessageParamUnion { - var res openai.ChatCompletionMessageParamUnion - - var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam - var textContent param.Opt[string] - - for _, img := range m.Images { - if img.Base64 != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ - URL: "data:" + img.ContentType + ";base64," + img.Base64, - }, - }, - }, - ) - } else if img.Url != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ - URL: img.Url, - }, - }, - }, - ) - } - } - - if m.Text != "" { - if len(arrayOfContentParts) > 0 { - arrayOfContentParts = append(arrayOfContentParts, - 
openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{ - Text: "\n", - }, - }, - ) - } else { - textContent = openai.String(m.Text) - } - } - - a := strings.Split(model, "-") - - useSystemInsteadOfDeveloper := true - if len(a) > 1 && a[0][0] == 'o' { - useSystemInsteadOfDeveloper = false - } - - switch m.Role { - case RoleSystem: - if useSystemInsteadOfDeveloper { - res = openai.ChatCompletionMessageParamUnion{ - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } else { - res = openai.ChatCompletionMessageParamUnion{ - OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } - - case RoleUser: - var name param.Opt[string] - if m.Name != "" { - name = openai.String(m.Name) - } - - res = openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Name: name, - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfString: textContent, - OfArrayOfContentParts: arrayOfContentParts, - }, - }, - } - - case RoleAssistant: - var name param.Opt[string] - if m.Name != "" { - name = openai.String(m.Name) - } - - res = openai.ChatCompletionMessageParamUnion{ - OfAssistant: &openai.ChatCompletionAssistantMessageParam{ - Name: name, - Content: openai.ChatCompletionAssistantMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } - - return []openai.ChatCompletionMessageParamUnion{res} -} - -func toolCallToChatCompletionMessages(t ToolCall) []openai.ChatCompletionMessageParamUnion { - return []openai.ChatCompletionMessageParamUnion{{ - OfAssistant: &openai.ChatCompletionAssistantMessageParam{ - ToolCalls: []openai.ChatCompletionMessageToolCallParam{ - { - ID: t.ID, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: t.FunctionCall.Name, - Arguments: 
t.FunctionCall.Arguments, - }, - }, - }, - }, - }} -} - -func toolCallResponseToChatCompletionMessages(t ToolCallResponse) []openai.ChatCompletionMessageParamUnion { - var refusal string - if t.Error != nil { - refusal = t.Error.Error() - } - - result := t.Result - if refusal != "" { - if result != "" { - result = fmt.Sprint(result) + " (error in execution: " + refusal + ")" - } else { - result = "error in execution:" + refusal - } - } - - return []openai.ChatCompletionMessageParamUnion{{ - OfTool: &openai.ChatCompletionToolMessageParam{ - ToolCallID: t.ID, - Content: openai.ChatCompletionToolMessageParamContentUnion{ - OfString: openai.String(fmt.Sprint(result)), - }, - }, - }} -} - -func responseChoiceToChatCompletionMessages(r ResponseChoice) []openai.ChatCompletionMessageParamUnion { - var as openai.ChatCompletionAssistantMessageParam - - if r.Name != "" { - as.Name = openai.String(r.Name) - } - if r.Refusal != "" { - as.Refusal = openai.String(r.Refusal) - } - - if r.Content != "" { - as.Content.OfString = openai.String(r.Content) - } - - for _, call := range r.Calls { - as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ - ID: call.ID, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: call.FunctionCall.Name, - Arguments: call.FunctionCall.Arguments, - }, - }) - } - return []openai.ChatCompletionMessageParamUnion{ - { - OfAssistant: &as, - }, - } -} diff --git a/openai_transcriber.go b/openai_transcriber.go deleted file mode 100644 index 9afe434..0000000 --- a/openai_transcriber.go +++ /dev/null @@ -1,219 +0,0 @@ -package llm - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" -) - -type openaiTranscriber struct { - key string - model string - baseUrl string -} - -var _ Transcriber = openaiTranscriber{} - -// OpenAITranscriber creates a transcriber backed by OpenAI's audio models. 
-// If model is empty, whisper-1 is used by default. -func OpenAITranscriber(key string, model string) Transcriber { - if strings.TrimSpace(model) == "" { - model = "whisper-1" - } - return openaiTranscriber{ - key: key, - model: model, - } -} - -func (o openaiTranscriber) Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) { - if len(wav) == 0 { - return Transcription{}, fmt.Errorf("wav data is empty") - } - - format := opts.ResponseFormat - if format == "" { - if strings.HasPrefix(o.model, "gpt-4o") { - format = TranscriptionResponseFormatJSON - } else { - format = TranscriptionResponseFormatVerboseJSON - } - } - - if format != TranscriptionResponseFormatJSON && format != TranscriptionResponseFormatVerboseJSON { - return Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output") - } - - if len(opts.TimestampGranularities) > 0 && format != TranscriptionResponseFormatVerboseJSON { - return Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json") - } - - params := openai.AudioTranscriptionNewParams{ - File: openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"), - Model: openai.AudioModel(o.model), - } - - if opts.Language != "" { - params.Language = openai.String(opts.Language) - } - if opts.Prompt != "" { - params.Prompt = openai.String(opts.Prompt) - } - if opts.Temperature != nil { - params.Temperature = openai.Float(*opts.Temperature) - } - - params.ResponseFormat = openai.AudioResponseFormat(format) - - if opts.IncludeLogprobs { - params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs} - } - - if len(opts.TimestampGranularities) > 0 { - for _, granularity := range opts.TimestampGranularities { - params.TimestampGranularities = append(params.TimestampGranularities, string(granularity)) - } - } - - clientOptions := []option.RequestOption{ - option.WithAPIKey(o.key), - } - if o.baseUrl != "" { - 
clientOptions = append(clientOptions, option.WithBaseURL(o.baseUrl)) - } - - client := openai.NewClient(clientOptions...) - resp, err := client.Audio.Transcriptions.New(ctx, params) - if err != nil { - return Transcription{}, fmt.Errorf("openai transcription failed: %w", err) - } - - return openaiTranscriptionToResult(o.model, resp), nil -} - -type openaiVerboseTranscription struct { - Text string `json:"text"` - Language string `json:"language"` - Duration float64 `json:"duration"` - Segments []openaiVerboseSegment `json:"segments"` - Words []openaiVerboseWord `json:"words"` -} - -type openaiVerboseSegment struct { - ID int `json:"id"` - Start float64 `json:"start"` - End float64 `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` - AvgLogprob *float64 `json:"avg_logprob"` - CompressionRatio *float64 `json:"compression_ratio"` - NoSpeechProb *float64 `json:"no_speech_prob"` - Words []openaiVerboseWord `json:"words"` -} - -type openaiVerboseWord struct { - Word string `json:"word"` - Start float64 `json:"start"` - End float64 `json:"end"` -} - -func openaiTranscriptionToResult(model string, resp *openai.Transcription) Transcription { - result := Transcription{ - Provider: "openai", - Model: model, - } - if resp == nil { - return result - } - - result.Text = resp.Text - result.RawJSON = resp.RawJSON() - - for _, logprob := range resp.Logprobs { - result.Logprobs = append(result.Logprobs, TranscriptionTokenLogprob{ - Token: logprob.Token, - Bytes: logprob.Bytes, - Logprob: logprob.Logprob, - }) - } - - if usage := openaiUsageToTranscriptionUsage(resp.Usage); usage.Type != "" { - result.Usage = usage - } - - if result.RawJSON == "" { - return result - } - - var verbose openaiVerboseTranscription - if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil { - return result - } - - if verbose.Text != "" { - result.Text = verbose.Text - } - result.Language = verbose.Language - result.DurationSeconds = verbose.Duration - - for _, seg := 
range verbose.Segments { - segment := TranscriptionSegment{ - ID: seg.ID, - Start: seg.Start, - End: seg.End, - Text: seg.Text, - Tokens: append([]int(nil), seg.Tokens...), - AvgLogprob: seg.AvgLogprob, - CompressionRatio: seg.CompressionRatio, - NoSpeechProb: seg.NoSpeechProb, - } - - for _, word := range seg.Words { - segment.Words = append(segment.Words, TranscriptionWord{ - Word: word.Word, - Start: word.Start, - End: word.End, - }) - } - - result.Segments = append(result.Segments, segment) - } - - for _, word := range verbose.Words { - result.Words = append(result.Words, TranscriptionWord{ - Word: word.Word, - Start: word.Start, - End: word.End, - }) - } - - return result -} - -func openaiUsageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) TranscriptionUsage { - switch usage.Type { - case "tokens": - tokens := usage.AsTokens() - return TranscriptionUsage{ - Type: usage.Type, - InputTokens: tokens.InputTokens, - OutputTokens: tokens.OutputTokens, - TotalTokens: tokens.TotalTokens, - AudioTokens: tokens.InputTokenDetails.AudioTokens, - TextTokens: tokens.InputTokenDetails.TextTokens, - } - case "duration": - duration := usage.AsDuration() - return TranscriptionUsage{ - Type: usage.Type, - Seconds: duration.Seconds, - } - default: - return TranscriptionUsage{} - } -} diff --git a/parse.go b/parse.go deleted file mode 100644 index a9614ff..0000000 --- a/parse.go +++ /dev/null @@ -1,50 +0,0 @@ -package llm - -import ( - "strings" -) - -// Providers are the allowed shortcuts in the providers, e.g.: if you set { "openai": OpenAI("key") } that'll allow -// for the "openai" provider to be used when parsed. -type Providers map[string]LLM - -// Parse will parse the provided input and attempt to return a LLM chat completion interface. 
-// Input should be in the provided format: -// - provider/modelname -// -// where provider is a key inside Providers, and the modelname being passed to the LLM interface's GetModel -func (providers Providers) Parse(input string) ChatCompletion { - sections := strings.Split(input, "/") - - var provider LLM - var ok bool - var modelVersion string - - if len(sections) < 2 { - // is there a default provider? - provider, ok = providers["default"] - if !ok { - panic("expected format: \"provider/model\" or provide a \"default\" provider to the Parse callback") - } - - modelVersion = sections[0] - } else { - provider, ok = providers[sections[0]] - modelVersion = sections[1] - } - - if !ok { - panic("expected format: \"provider/model\" or provide a \"default\" provider to the Parse callback") - } - - if provider == nil { - panic("unknown provider: " + sections[0]) - } - - res, err := provider.ModelVersion(modelVersion) - if err != nil { - panic(err) - } - - return res -} diff --git a/provider/anthropic/anthropic.go b/provider/anthropic/anthropic.go deleted file mode 100644 index 32e13e4..0000000 --- a/provider/anthropic/anthropic.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package anthropic provides the Anthropic LLM provider. -package anthropic - -import ( - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// New creates a new Anthropic LLM provider with the given API key. -func New(key string) llm.LLM { - return llm.Anthropic(key) -} diff --git a/provider/google/google.go b/provider/google/google.go deleted file mode 100644 index df5fe72..0000000 --- a/provider/google/google.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package google provides the Google LLM provider. -package google - -import ( - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// New creates a new Google LLM provider with the given API key. 
-func New(key string) llm.LLM { - return llm.Google(key) -} diff --git a/provider/openai/openai.go b/provider/openai/openai.go deleted file mode 100644 index 2c853f9..0000000 --- a/provider/openai/openai.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package openai provides the OpenAI LLM provider. -package openai - -import ( - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// New creates a new OpenAI LLM provider with the given API key. -func New(key string) llm.LLM { - return llm.OpenAI(key) -} diff --git a/request.go b/request.go deleted file mode 100644 index 43e430f..0000000 --- a/request.go +++ /dev/null @@ -1,51 +0,0 @@ -package llm - -// Input is the interface for conversation inputs. -// Types that implement this interface can be part of a conversation: -// Message, ToolCall, ToolCallResponse, and ResponseChoice. -type Input interface { - // isInput is a marker method to ensure only valid types implement this interface. - isInput() -} - -// Implement Input interface for all valid input types. -func (Message) isInput() {} -func (ToolCall) isInput() {} -func (ToolCallResponse) isInput() {} -func (ResponseChoice) isInput() {} - -// Request represents a request to a language model. -type Request struct { - Conversation []Input - Messages []Message - Toolbox ToolBox - Temperature *float64 -} - -// NextRequest will take the current request's conversation, messages, the response, and any tool results, and -// return a new request with the conversation updated to include the response and tool results. 
-func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request { - var res Request - - res.Toolbox = req.Toolbox - res.Temperature = req.Temperature - - res.Conversation = make([]Input, len(req.Conversation)) - copy(res.Conversation, req.Conversation) - - // now for every input message, convert those to an Input to add to the conversation - for _, msg := range req.Messages { - res.Conversation = append(res.Conversation, msg) - } - - if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 { - res.Conversation = append(res.Conversation, resp) - } - - // if there are tool results, then we need to add those to the conversation - for _, result := range toolResults { - res.Conversation = append(res.Conversation, result) - } - - return res -} diff --git a/response.go b/response.go deleted file mode 100644 index e7043b8..0000000 --- a/response.go +++ /dev/null @@ -1,52 +0,0 @@ -package llm - -// ResponseChoice represents a single choice in a response. -type ResponseChoice struct { - Index int - Role Role - Content string - Refusal string - Name string - Calls []ToolCall -} - -func (r ResponseChoice) toRaw() map[string]any { - res := map[string]any{ - "index": r.Index, - "role": r.Role, - "content": r.Content, - "refusal": r.Refusal, - "name": r.Name, - } - - calls := make([]map[string]any, 0, len(r.Calls)) - for _, call := range r.Calls { - calls = append(calls, call.toRaw()) - } - - res["tool_calls"] = calls - - return res -} - -func (r ResponseChoice) toInput() []Input { - var res []Input - - for _, call := range r.Calls { - res = append(res, call) - } - - if r.Content != "" || r.Refusal != "" { - res = append(res, Message{ - Role: RoleAssistant, - Text: r.Content, - }) - } - - return res -} - -// Response represents a response from a language model. 
-type Response struct { - Choices []ResponseChoice -} diff --git a/schema/GetType.go b/schema/GetType.go deleted file mode 100644 index 342dea2..0000000 --- a/schema/GetType.go +++ /dev/null @@ -1,142 +0,0 @@ -package schema - -import ( - "reflect" - "strings" -) - -// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that -// can be used to generate a json schema and build an object from a parsed json object. -func GetType(a any) Type { - t := reflect.TypeOf(a) - - if t.Kind() != reflect.Struct { - panic("GetType expects a struct") - } - - return getObject(t) -} - -func getFromType(t reflect.Type, b basic) Type { - if t.Kind() == reflect.Ptr { - t = t.Elem() - b.required = false - } - - switch t.Kind() { - case reflect.String: - b.DataType = TypeString - b.typeName = "string" - return b - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - b.DataType = TypeInteger - b.typeName = "integer" - return b - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - b.DataType = TypeInteger - b.typeName = "integer" - return b - - case reflect.Float32, reflect.Float64: - b.DataType = TypeNumber - b.typeName = "number" - return b - - case reflect.Bool: - b.DataType = TypeBoolean - b.typeName = "boolean" - return b - - case reflect.Struct: - o := getObject(t) - - o.basic.required = b.required - o.basic.index = b.index - o.basic.description = b.description - - return o - - case reflect.Slice: - return getArray(t) - - default: - panic("unhandled default case for " + t.Kind().String() + " in getFromType") - } -} - -func getField(f reflect.StructField, index int) Type { - b := basic{ - index: index, - required: true, - description: "", - } - - t := f.Type - - // if the tag "description" is set, use that as the description - if desc, ok := f.Tag.Lookup("description"); ok { - b.description = desc - } - - // now if the tag "enum" is set, we need to create an enum 
type - if v, ok := f.Tag.Lookup("enum"); ok { - vals := strings.Split(v, ",") - - for i := 0; i < len(vals); i++ { - vals[i] = strings.TrimSpace(vals[i]) - - if vals[i] == "" { - vals = append(vals[:i], vals[i+1:]...) - } - } - - b.DataType = TypeString - b.typeName = "string" - return enum{ - basic: b, - values: vals, - } - - } - - return getFromType(t, b) -} - -func getObject(t reflect.Type) Object { - fields := make(map[string]Type, t.NumField()) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - if field.Anonymous { - // if the field is anonymous, we need to get the fields of the anonymous struct - // and add them to the object - anon := getObject(field.Type) - for k, v := range anon.fields { - fields[k] = v - } - continue - } else { - fields[field.Name] = getField(field, i) - } - } - - return Object{ - basic: basic{DataType: TypeObject, typeName: "object"}, - fields: fields, - } -} - -func getArray(t reflect.Type) array { - res := array{ - basic: basic{ - DataType: TypeArray, - typeName: "array", - }, - } - - res.items = getFromType(t.Elem(), basic{}) - - return res -} diff --git a/schema/array.go b/schema/array.go deleted file mode 100644 index c18c75c..0000000 --- a/schema/array.go +++ /dev/null @@ -1,77 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -type array struct { - basic - - // items is the schema of the items in the array - items Type -} - -func (a array) OpenAIParameters() openai.FunctionParameters { - return openai.FunctionParameters{ - "type": "array", - "description": a.Description(), - "items": a.items.OpenAIParameters(), - } -} - -func (a array) GoogleParameters() *genai.Schema { - return &genai.Schema{ - Type: genai.TypeArray, - Description: a.Description(), - Items: a.items.GoogleParameters(), - } -} - -func (a array) AnthropicInputSchema() map[string]any { - return map[string]any{ - "type": "array", - "description": a.Description(), - "items": 
a.items.AnthropicInputSchema(), - } -} - -func (a array) FromAny(val any) (reflect.Value, error) { - v := reflect.ValueOf(val) - - // first realize we may have a pointer to a slice if this type is not required - if !a.required && v.Kind() == reflect.Ptr { - v = v.Elem() - } - - if v.Kind() != reflect.Slice { - return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String()) - } - - // if the slice is nil, we can just return it - if v.IsNil() { - return v, nil - } - - // if the slice is not nil, we need to convert each item - items := make([]reflect.Value, v.Len()) - for i := 0; i < v.Len(); i++ { - item, err := a.items.FromAny(v.Index(i).Interface()) - if err != nil { - return reflect.Value{}, err - } - items[i] = item - } - - return reflect.ValueOf(items), nil -} - -func (a array) SetValue(obj reflect.Value, val reflect.Value) { - if !a.required { - val = val.Addr() - } - obj.Field(a.index).Set(val) -} diff --git a/schema/basic.go b/schema/basic.go deleted file mode 100644 index a38f5a4..0000000 --- a/schema/basic.go +++ /dev/null @@ -1,165 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - "strconv" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -// just enforcing that basic implements Type -var _ Type = basic{} - -type DataType string - -const ( - TypeString DataType = "string" - TypeInteger DataType = "integer" - TypeNumber DataType = "number" - TypeBoolean DataType = "boolean" - TypeObject DataType = "object" - TypeArray DataType = "array" -) - -type basic struct { - DataType - typeName string - - // index is the position of the parameter in the StructField of the function's parameter struct - index int - - // required is a flag that indicates whether the parameter is required in the function's parameter struct. - // this is inferred by if the parameter is a pointer type or not. 
- required bool - - // description is a llm-readable description of the parameter passed to openai - description string -} - -func (b basic) OpenAIParameters() openai.FunctionParameters { - return openai.FunctionParameters{ - "type": b.typeName, - "description": b.description, - } -} - -func (b basic) GoogleParameters() *genai.Schema { - var t = genai.TypeUnspecified - - switch b.DataType { - case TypeString: - t = genai.TypeString - case TypeInteger: - t = genai.TypeInteger - case TypeNumber: - t = genai.TypeNumber - case TypeBoolean: - t = genai.TypeBoolean - case TypeObject: - t = genai.TypeObject - case TypeArray: - t = genai.TypeArray - default: - t = genai.TypeUnspecified - } - return &genai.Schema{ - Type: t, - Description: b.description, - } -} - -func (b basic) AnthropicInputSchema() map[string]any { - var t = "string" - - switch b.DataType { - case TypeString: - t = "string" - case TypeInteger: - t = "integer" - case TypeNumber: - t = "number" - case TypeBoolean: - t = "boolean" - case TypeObject: - t = "object" - case TypeArray: - t = "array" - default: - t = "unknown" - } - - return map[string]any{ - "type": t, - "description": b.description, - } -} - -func (b basic) Required() bool { - return b.required -} - -func (b basic) Description() string { - return b.description -} - -func (b basic) FromAny(val any) (reflect.Value, error) { - v := reflect.ValueOf(val) - - switch b.DataType { - case TypeString: - var val = v.String() - - return reflect.ValueOf(val), nil - - case TypeInteger: - if v.Kind() == reflect.Float64 { - return v.Convert(reflect.TypeOf(int(0))), nil - } else if v.Kind() != reflect.Int { - return reflect.Value{}, errors.New("expected int, got " + v.Kind().String()) - } else { - return v, nil - } - - case TypeNumber: - if v.Kind() == reflect.Float64 { - return v.Convert(reflect.TypeOf(float64(0))), nil - } else if v.Kind() != reflect.Float64 { - return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String()) - } else { - 
return v, nil - } - - case TypeBoolean: - if v.Kind() == reflect.Bool { - return v, nil - } else if v.Kind() == reflect.String { - b, err := strconv.ParseBool(v.String()) - if err != nil { - return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) - } - return reflect.ValueOf(b), nil - } else { - return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) - } - - default: - return reflect.Value{}, errors.New("unknown type") - } -} - -func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) { - // if this basic type is not required that means it's a pointer type - // so we need to create a new value of the type of the pointer - if !b.required { - vv := reflect.New(obj.Field(b.index).Type().Elem()) - - // and then set the value of the pointer to the new value - vv.Elem().Set(val) - - obj.Field(b.index).Set(vv) - return - } - obj.Field(b.index).Set(val) -} diff --git a/schema/enum.go b/schema/enum.go deleted file mode 100644 index 1b473a7..0000000 --- a/schema/enum.go +++ /dev/null @@ -1,61 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - "slices" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -type enum struct { - basic - - values []string -} - -func (e enum) FunctionParameters() openai.FunctionParameters { - return openai.FunctionParameters{ - "type": "string", - "description": e.Description(), - "enum": e.values, - } -} - -func (e enum) GoogleParameters() *genai.Schema { - return &genai.Schema{ - Type: genai.TypeString, - Description: e.Description(), - Enum: e.values, - } -} - -func (e enum) AnthropicInputSchema() map[string]any { - return map[string]any{ - "type": "string", - "description": e.Description(), - "enum": e.values, - } -} - -func (e enum) FromAny(val any) (reflect.Value, error) { - v := reflect.ValueOf(val) - if v.Kind() != reflect.String { - return reflect.Value{}, errors.New("expected string, got " + v.Kind().String()) - } - - s := v.String() - if 
!slices.Contains(e.values, s) { - return reflect.Value{}, errors.New("value " + s + " not in enum") - } - - return v, nil -} - -func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) { - if !e.required { - val = val.Addr() - } - obj.Field(e.index).Set(val) -} diff --git a/schema/object.go b/schema/object.go deleted file mode 100644 index 053baa8..0000000 --- a/schema/object.go +++ /dev/null @@ -1,169 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -const ( - // SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent - // collisions with the fields in the struct. - SyntheticFieldPrefix = "__" -) - -type Object struct { - basic - - ref reflect.Type - - fields map[string]Type - - // syntheticFields are fields that are not in the struct but are generated by a system. - synetheticFields map[string]Type -} - -func (o Object) WithSyntheticField(name string, description string) Object { - if o.synetheticFields == nil { - o.synetheticFields = map[string]Type{} - } - - o.synetheticFields[name] = basic{ - DataType: TypeString, - typeName: "string", - index: -1, - required: false, - description: description, - } - - return o -} - -func (o Object) SyntheticFields() map[string]Type { - return o.synetheticFields -} - -func (o Object) OpenAIParameters() openai.FunctionParameters { - var properties = map[string]openai.FunctionParameters{} - var required []string - for k, v := range o.fields { - properties[k] = v.OpenAIParameters() - if v.Required() { - required = append(required, k) - } - } - - for k, v := range o.synetheticFields { - properties[SyntheticFieldPrefix+k] = v.OpenAIParameters() - if v.Required() { - required = append(required, SyntheticFieldPrefix+k) - } - } - - var res = openai.FunctionParameters{ - "type": "object", - "description": o.Description(), - "properties": properties, - } - - if len(required) > 0 { - 
res["required"] = required - } - - return res -} - -func (o Object) GoogleParameters() *genai.Schema { - var properties = map[string]*genai.Schema{} - var required []string - for k, v := range o.fields { - properties[k] = v.GoogleParameters() - if v.Required() { - required = append(required, k) - } - } - - var res = &genai.Schema{ - Type: genai.TypeObject, - Description: o.Description(), - Properties: properties, - } - - if len(required) > 0 { - res.Required = required - } - - return res -} - -func (o Object) AnthropicInputSchema() map[string]any { - var properties = map[string]any{} - var required []string - for k, v := range o.fields { - properties[k] = v.AnthropicInputSchema() - if v.Required() { - required = append(required, k) - } - } - - var res = map[string]any{ - "type": "object", - "description": o.Description(), - "properties": properties, - } - - if len(required) > 0 { - res["required"] = required - } - - return res -} - -// FromAny converts the value from any to the correct type, returning the value, and an error if any -func (o Object) FromAny(val any) (reflect.Value, error) { - // if the value is nil, we can't do anything - if val == nil { - return reflect.Value{}, nil - } - - // now make a new object of the type we're trying to parse - obj := reflect.New(o.ref).Elem() - - // now we need to iterate over the fields and set the values - for k, v := range o.fields { - // get the field by name - field := obj.FieldByName(k) - if !field.IsValid() { - return reflect.Value{}, errors.New("field " + k + " not found") - } - - // get the value from the map - val2, ok := val.(map[string]interface{})[k] - if !ok { - return reflect.Value{}, errors.New("field " + k + " not found in map") - } - - // now we need to convert the value to the correct type - val3, err := v.FromAny(val2) - if err != nil { - return reflect.Value{}, err - } - - // now we need to set the value on the field - v.SetValueOnField(field, val3) - - } - - return obj, nil -} - -func (o Object) 
SetValueOnField(obj reflect.Value, val reflect.Value) { - // if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value - if !o.required { - val = val.Addr() - } - - obj.Field(o.index).Set(val) -} diff --git a/schema/raw.go b/schema/raw.go deleted file mode 100644 index b4dcb27..0000000 --- a/schema/raw.go +++ /dev/null @@ -1,134 +0,0 @@ -package schema - -import ( - "encoding/json" - "fmt" - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -// Raw represents a raw JSON schema that is passed through directly. -// This is used for MCP tools where we receive the schema from the server. -type Raw struct { - schema map[string]any -} - -// NewRaw creates a new Raw schema from a map. -func NewRaw(schema map[string]any) Raw { - if schema == nil { - schema = map[string]any{ - "type": "object", - "properties": map[string]any{}, - } - } - return Raw{schema: schema} -} - -// NewRawFromJSON creates a new Raw schema from JSON bytes. 
-func NewRawFromJSON(data []byte) (Raw, error) { - var schema map[string]any - if err := json.Unmarshal(data, &schema); err != nil { - return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err) - } - return NewRaw(schema), nil -} - -func (r Raw) OpenAIParameters() openai.FunctionParameters { - return openai.FunctionParameters(r.schema) -} - -func (r Raw) GoogleParameters() *genai.Schema { - return mapToGenaiSchema(r.schema) -} - -func (r Raw) AnthropicInputSchema() map[string]any { - return r.schema -} - -func (r Raw) Required() bool { - return false -} - -func (r Raw) Description() string { - if desc, ok := r.schema["description"].(string); ok { - return desc - } - return "" -} - -func (r Raw) FromAny(val any) (reflect.Value, error) { - return reflect.ValueOf(val), nil -} - -func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) { - // No-op for raw schemas -} - -// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema -func mapToGenaiSchema(m map[string]any) *genai.Schema { - if m == nil { - return nil - } - - schema := &genai.Schema{} - - // Type - if t, ok := m["type"].(string); ok { - switch t { - case "string": - schema.Type = genai.TypeString - case "number": - schema.Type = genai.TypeNumber - case "integer": - schema.Type = genai.TypeInteger - case "boolean": - schema.Type = genai.TypeBoolean - case "array": - schema.Type = genai.TypeArray - case "object": - schema.Type = genai.TypeObject - } - } - - // Description - if desc, ok := m["description"].(string); ok { - schema.Description = desc - } - - // Enum - if enum, ok := m["enum"].([]any); ok { - for _, e := range enum { - if s, ok := e.(string); ok { - schema.Enum = append(schema.Enum, s) - } - } - } - - // Properties (for objects) - if props, ok := m["properties"].(map[string]any); ok { - schema.Properties = make(map[string]*genai.Schema) - for k, v := range props { - if vm, ok := v.(map[string]any); ok { - schema.Properties[k] = mapToGenaiSchema(vm) - } - } - } - - // 
Required - if req, ok := m["required"].([]any); ok { - for _, r := range req { - if s, ok := r.(string); ok { - schema.Required = append(schema.Required, s) - } - } - } - - // Items (for arrays) - if items, ok := m["items"].(map[string]any); ok { - schema.Items = mapToGenaiSchema(items) - } - - return schema -} diff --git a/schema/type.go b/schema/type.go deleted file mode 100644 index 318f9b6..0000000 --- a/schema/type.go +++ /dev/null @@ -1,23 +0,0 @@ -package schema - -import ( - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -type Type interface { - OpenAIParameters() openai.FunctionParameters - GoogleParameters() *genai.Schema - AnthropicInputSchema() map[string]any - - //SchemaType() jsonschema.DataType - //Definition() jsonschema.Definition - - Required() bool - Description() string - - FromAny(any) (reflect.Value, error) - SetValueOnField(obj reflect.Value, val reflect.Value) -} diff --git a/toolbox.go b/toolbox.go deleted file mode 100644 index 10cb892..0000000 --- a/toolbox.go +++ /dev/null @@ -1,174 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "errors" - "fmt" -) - -// ToolBox is a collection of tools that OpenAI can use to execute functions. -// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with -// the correct parameters. 
-type ToolBox struct { - functions map[string]Function - mcpServers map[string]*MCPServer // tool name -> MCP server that provides it - dontRequireTool bool -} - -func NewToolBox(fns ...Function) ToolBox { - res := ToolBox{ - functions: map[string]Function{}, - } - - for _, f := range fns { - res.functions[f.Name] = f - } - - return res -} - -func (t ToolBox) Functions() []Function { - var res []Function - - for _, f := range t.functions { - res = append(res, f) - } - - return res -} - -func (t ToolBox) WithFunction(f Function) ToolBox { - t.functions[f.Name] = f - - return t -} - -func (t ToolBox) WithFunctions(fns ...Function) ToolBox { - for _, f := range fns { - t.functions[f.Name] = f - } - - return t -} - -func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox { - for k, v := range t.functions { - t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions) - } - - return t -} - -func (t ToolBox) ForEachFunction(fn func(f Function)) { - for _, f := range t.functions { - fn(f) - } -} - -func (t ToolBox) WithFunctionRemoved(name string) ToolBox { - delete(t.functions, name) - return t -} - -func (t ToolBox) WithRequireTool(val bool) ToolBox { - t.dontRequireTool = !val - return t -} - -func (t ToolBox) RequiresTool() bool { - return !t.dontRequireTool && len(t.functions) > 0 -} - -func (t ToolBox) ToToolChoice() any { - if len(t.functions) == 0 { - return nil - } - - return "required" -} - -var ( - ErrFunctionNotFound = errors.New("function not found") -) - -func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { - // Check if this is an MCP tool - if server, ok := t.mcpServers[functionName]; ok { - var args map[string]any - if params != "" { - if err := json.Unmarshal([]byte(params), &args); err != nil { - return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err) - } - } - return server.CallTool(ctx, functionName, args) - } - - // Regular function - f, ok 
:= t.functions[functionName] - - if !ok { - return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) - } - - return f.Execute(ctx, params) -} - -func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { - return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) -} - -func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string { - val := ctx.Value("syntheticParameters") - - if val == nil { - return nil - } - - syntheticParameters, ok := val.(map[string]string) - if !ok { - return nil - } - - return syntheticParameters -} - -// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. -// OnNewFunction is called when a new function is created -// OnFunctionFinished is called when a function is finished -func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { - var res []ToolCallResponse - - for _, call := range toolCalls { - ctx := ctx.WithToolCall(&call) - if call.FunctionCall.Name == "" { - return nil, newError(ErrFunctionNotFound, errors.New("function name is empty")) - } - - var arg any - if OnNewFunction != nil { - var err error - arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments) - - if err != nil { - return nil, newError(ErrFunctionNotFound, err) - } - } - out, err := t.Execute(ctx, call) - - if OnFunctionFinished != nil { - err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg) - if err != nil { - return nil, newError(ErrFunctionNotFound, err) - } - } - - res = append(res, 
ToolCallResponse{ - ID: call.ID, - Result: out, - Error: err, - }) - } - - return res, nil -} diff --git a/transcriber.go b/transcriber.go deleted file mode 100644 index 9cef7f3..0000000 --- a/transcriber.go +++ /dev/null @@ -1,145 +0,0 @@ -package llm - -import ( - "context" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" -) - -// Transcriber abstracts a speech-to-text model implementation. -type Transcriber interface { - Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) -} - -// TranscriptionResponseFormat controls the output format requested from a transcriber. -type TranscriptionResponseFormat string - -const ( - TranscriptionResponseFormatJSON TranscriptionResponseFormat = "json" - TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json" - TranscriptionResponseFormatText TranscriptionResponseFormat = "text" - TranscriptionResponseFormatSRT TranscriptionResponseFormat = "srt" - TranscriptionResponseFormatVTT TranscriptionResponseFormat = "vtt" -) - -// TranscriptionTimestampGranularity defines the requested timestamp detail. -type TranscriptionTimestampGranularity string - -const ( - TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" - TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" -) - -// TranscriptionOptions configures transcription behavior. -type TranscriptionOptions struct { - Language string - Prompt string - Temperature *float64 - ResponseFormat TranscriptionResponseFormat - TimestampGranularities []TranscriptionTimestampGranularity - IncludeLogprobs bool -} - -// Transcription captures a normalized transcription result. 
-type Transcription struct { - Provider string - Model string - Text string - Language string - DurationSeconds float64 - Segments []TranscriptionSegment - Words []TranscriptionWord - Logprobs []TranscriptionTokenLogprob - Usage TranscriptionUsage - RawJSON string -} - -// TranscriptionSegment provides a coarse time-sliced transcription segment. -type TranscriptionSegment struct { - ID int - Start float64 - End float64 - Text string - Tokens []int - AvgLogprob *float64 - CompressionRatio *float64 - NoSpeechProb *float64 - Words []TranscriptionWord -} - -// TranscriptionWord provides a word-level timestamp. -type TranscriptionWord struct { - Word string - Start float64 - End float64 - Confidence *float64 -} - -// TranscriptionTokenLogprob captures token-level log probability details. -type TranscriptionTokenLogprob struct { - Token string - Bytes []float64 - Logprob float64 -} - -// TranscriptionUsage captures token or duration usage details. -type TranscriptionUsage struct { - Type string - InputTokens int64 - OutputTokens int64 - TotalTokens int64 - AudioTokens int64 - TextTokens int64 - Seconds float64 -} - -// TranscribeFile converts an audio file to WAV and transcribes it. 
-func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) { - if transcriber == nil { - return Transcription{}, fmt.Errorf("transcriber is nil") - } - - wav, err := audioFileToWav(ctx, filename) - if err != nil { - return Transcription{}, err - } - - return transcriber.Transcribe(ctx, wav, opts) -} - -func audioFileToWav(ctx context.Context, filename string) ([]byte, error) { - if filename == "" { - return nil, fmt.Errorf("filename is empty") - } - - if strings.EqualFold(filepath.Ext(filename), ".wav") { - data, err := os.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("read wav file: %w", err) - } - return data, nil - } - - tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav") - if err != nil { - return nil, fmt.Errorf("create temp wav file: %w", err) - } - tempPath := tempFile.Name() - _ = tempFile.Close() - defer os.Remove(tempPath) - - cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath) - if output, err := cmd.CombinedOutput(); err != nil { - return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output))) - } - - data, err := os.ReadFile(tempPath) - if err != nil { - return nil, fmt.Errorf("read converted wav file: %w", err) - } - - return data, nil -} diff --git a/v2/cmd/llm/.env.example b/v2/cmd/llm/.env.example new file mode 100644 index 0000000..b8369fd --- /dev/null +++ b/v2/cmd/llm/.env.example @@ -0,0 +1,27 @@ +# go-llm CLI environment variables +# Copy this file to .env and fill in the keys for providers you use. 
+ +# OpenAI API Key (https://platform.openai.com/api-keys) +OPENAI_API_KEY= + +# Anthropic API Key (https://console.anthropic.com/settings/keys) +ANTHROPIC_API_KEY= + +# Google AI API Key (https://aistudio.google.com/apikey) +GOOGLE_API_KEY= + +# DeepSeek API Key (https://platform.deepseek.com) +DEEPSEEK_API_KEY= + +# Moonshot / Kimi API Key (https://platform.moonshot.ai) +MOONSHOT_API_KEY= + +# xAI / Grok API Key (https://x.ai/api) +XAI_API_KEY= + +# Groq API Key (https://console.groq.com/keys) +GROQ_API_KEY= + +# Ollama runs locally with no API key required. +# Override the endpoint if you're not using localhost:11434. +# OLLAMA_BASE_URL=http://localhost:11434/v1 diff --git a/v2/cmd/llm/commands.go b/v2/cmd/llm/commands.go new file mode 100644 index 0000000..03cefb5 --- /dev/null +++ b/v2/cmd/llm/commands.go @@ -0,0 +1,136 @@ +package main + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + + tea "github.com/charmbracelet/bubbletea" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// Message types for async operations. + +// ChatResponseMsg contains the response from a chat completion. +type ChatResponseMsg struct { + Response llm.Response + Err error +} + +// ToolExecutionMsg contains results from executing tool calls, one Message +// (RoleTool) per ToolCall, in the same order. +type ToolExecutionMsg struct { + Results []llm.Message + Err error +} + +// ImageLoadedMsg contains a loaded image. +type ImageLoadedMsg struct { + Image llm.Image + Err error +} + +// sendChatRequest sends a completion request with the current conversation, +// returning a ChatResponseMsg tea.Msg when the provider responds. +func sendChatRequest(model *llm.Model, messages []llm.Message, toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) tea.Cmd { + return func() tea.Msg { + opts := buildOpts(toolbox, toolsEnabled, temperature) + resp, err := model.Complete(context.Background(), messages, opts...) 
+ return ChatResponseMsg{Response: resp, Err: err} + } +} + +// executeTools runs each tool call via the toolbox and returns ToolExecutionMsg +// with one RoleTool Message per call, in the same order. +func executeTools(toolbox *llm.ToolBox, calls []llm.ToolCall) tea.Cmd { + return func() tea.Msg { + ctx := context.Background() + results, err := toolbox.ExecuteAll(ctx, calls) + return ToolExecutionMsg{Results: results, Err: err} + } +} + +// buildOpts constructs RequestOptions from the current CLI state. +func buildOpts(toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) []llm.RequestOption { + var opts []llm.RequestOption + if toolsEnabled && toolbox != nil && len(toolbox.AllTools()) > 0 { + opts = append(opts, llm.WithTools(toolbox)) + } + if temperature != nil { + opts = append(opts, llm.WithTemperature(*temperature)) + } + return opts +} + +// loadImageFromPath loads an image from a file path. +func loadImageFromPath(path string) tea.Cmd { + return func() tea.Msg { + path = strings.TrimSpace(path) + path = strings.Trim(path, "\"'") + + data, err := os.ReadFile(path) + if err != nil { + return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)} + } + + contentType := http.DetectContentType(data) + if !strings.HasPrefix(contentType, "image/") { + return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)} + } + + return ImageLoadedMsg{ + Image: llm.Image{ + Base64: base64.StdEncoding.EncodeToString(data), + ContentType: contentType, + }, + } + } +} + +// loadImageFromURL loads an image from a URL (kept as URL, not fetched). +func loadImageFromURL(url string) tea.Cmd { + return func() tea.Msg { + return ImageLoadedMsg{Image: llm.Image{URL: strings.TrimSpace(url)}} + } +} + +// loadImageFromBase64 loads an image from base64 data (raw or data: URL). 
+func loadImageFromBase64(data string) tea.Cmd { + return func() tea.Msg { + data = strings.TrimSpace(data) + + if strings.HasPrefix(data, "data:") { + parts := strings.SplitN(data, ",", 2) + if len(parts) != 2 { + return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")} + } + mediaType := strings.TrimPrefix(parts[0], "data:") + mediaType = strings.TrimSuffix(mediaType, ";base64") + return ImageLoadedMsg{ + Image: llm.Image{ + Base64: parts[1], + ContentType: mediaType, + }, + } + } + + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)} + } + contentType := http.DetectContentType(decoded) + if !strings.HasPrefix(contentType, "image/") { + return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)} + } + return ImageLoadedMsg{ + Image: llm.Image{ + Base64: data, + ContentType: contentType, + }, + } + } +} diff --git a/cmd/llm/main.go b/v2/cmd/llm/main.go similarity index 100% rename from cmd/llm/main.go rename to v2/cmd/llm/main.go diff --git a/cmd/llm/model.go b/v2/cmd/llm/model.go similarity index 55% rename from cmd/llm/model.go rename to v2/cmd/llm/model.go index 282a0b5..6bcfa8c 100644 --- a/cmd/llm/model.go +++ b/v2/cmd/llm/model.go @@ -7,10 +7,10 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" - llm "gitea.stevedudenhoeffer.com/steve/go-llm" + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" ) -// State represents the current view/screen of the application +// State represents the current view/screen of the application. type State int const ( @@ -23,43 +23,42 @@ const ( StateAPIKeyInput ) -// DisplayMessage represents a message for display in the UI +// DisplayMessage represents a message for display in the UI. 
type DisplayMessage struct { Role llm.Role Content string Images int // number of images attached } -// ProviderInfo contains information about a provider -type ProviderInfo struct { - Name string - EnvVar string - Models []string +// ProviderEntry is a CLI-local view of a registered provider, enriched with +// UI state (which model is currently chosen, whether we have a key, etc.). +type ProviderEntry struct { + Info llm.ProviderInfo HasAPIKey bool ModelIndex int } -// Model is the main Bubble Tea model +// Model is the main Bubble Tea model. type Model struct { // State state State previousState State // Provider - provider llm.LLM + client *llm.Client + chat *llm.Model providerName string - chat llm.ChatCompletion modelName string apiKeys map[string]string - providers []ProviderInfo + providers []ProviderEntry providerIndex int // Conversation - conversation []llm.Input + conversation []llm.Message messages []DisplayMessage // Tools - toolbox llm.ToolBox + toolbox *llm.ToolBox toolsEnabled bool // Settings @@ -90,7 +89,7 @@ type Model struct { apiKeyInput textinput.Model } -// InitialModel creates and returns the initial model +// InitialModel creates and returns the initial model. func InitialModel() Model { ti := textinput.New() ti.Placeholder = "Type your message..." 
@@ -104,60 +103,21 @@ func InitialModel() Model { aki.Width = 60 aki.EchoMode = textinput.EchoPassword - // Initialize providers with environment variable checks - providers := []ProviderInfo{ - { - Name: "OpenAI", - EnvVar: "OPENAI_API_KEY", - Models: []string{ - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-4o", - "gpt-4o-mini", - "gpt-4-turbo", - "gpt-3.5-turbo", - "o1", - "o1-mini", - "o1-preview", - "o3-mini", - }, - }, - { - Name: "Anthropic", - EnvVar: "ANTHROPIC_API_KEY", - Models: []string{ - "claude-sonnet-4-20250514", - "claude-opus-4-20250514", - "claude-3-7-sonnet-20250219", - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - }, - }, - { - Name: "Google", - EnvVar: "GOOGLE_API_KEY", - Models: []string{ - "gemini-2.0-flash", - "gemini-2.0-flash-lite", - "gemini-1.5-pro", - "gemini-1.5-flash", - "gemini-1.5-flash-8b", - "gemini-1.0-pro", - }, - }, - } - - // Check for API keys in environment + // Build provider list from the go-llm registry. + registry := llm.Providers() + providers := make([]ProviderEntry, len(registry)) apiKeys := make(map[string]string) - for i := range providers { - if key := os.Getenv(providers[i].EnvVar); key != "" { - apiKeys[providers[i].Name] = key - providers[i].HasAPIKey = true + + for i, info := range registry { + entry := ProviderEntry{Info: info} + if info.EnvKey == "" { + // Key-less provider (e.g., Ollama). + entry.HasAPIKey = true + } else if key := os.Getenv(info.EnvKey); key != "" { + apiKeys[info.Name] = key + entry.HasAPIKey = true } + providers[i] = entry } m := Model{ @@ -170,97 +130,87 @@ func InitialModel() Model { toolbox: createDemoToolbox(), toolsEnabled: false, messages: []DisplayMessage{}, - conversation: []llm.Input{}, + conversation: []llm.Message{}, } - // Build list items for provider selection + // Build list items for provider selection. 
m.listItems = make([]string, len(providers)) for i, p := range providers { status := " (no key)" if p.HasAPIKey { status = " (ready)" + if p.Info.EnvKey == "" { + status = " (local)" + } } - m.listItems[i] = p.Name + status + m.listItems[i] = p.Info.DisplayName + status } return m } -// Init initializes the model +// Init initializes the model. func (m Model) Init() tea.Cmd { return textinput.Blink } -// selectProvider sets up the selected provider +// selectProvider sets up the selected provider. func (m *Model) selectProvider(index int) error { if index < 0 || index >= len(m.providers) { return nil } p := m.providers[index] - key, ok := m.apiKeys[p.Name] - if !ok || key == "" { + key := m.apiKeys[p.Info.Name] // empty for key-less providers like Ollama + + if p.Info.EnvKey != "" && key == "" { return nil } - m.providerName = p.Name + m.providerName = p.Info.DisplayName m.providerIndex = index + m.client = p.Info.New(key) - switch p.Name { - case "OpenAI": - m.provider = llm.OpenAI(key) - case "Anthropic": - m.provider = llm.Anthropic(key) - case "Google": - m.provider = llm.Google(key) - } - - // Select default model - if len(p.Models) > 0 { + // Select default model. + if len(p.Info.Models) > 0 { return m.selectModel(p.ModelIndex) } return nil } -// selectModel sets the current model +// selectModel sets the current model. func (m *Model) selectModel(index int) error { - if m.provider == nil { + if m.client == nil { return nil } p := m.providers[m.providerIndex] - if index < 0 || index >= len(p.Models) { + if index < 0 || index >= len(p.Info.Models) { return nil } - modelName := p.Models[index] - chat, err := m.provider.ModelVersion(modelName) - if err != nil { - return err - } - - m.chat = chat + modelName := p.Info.Models[index] + m.chat = m.client.Model(modelName) m.modelName = modelName m.providers[m.providerIndex].ModelIndex = index return nil } -// newConversation resets the conversation +// newConversation resets the conversation. 
func (m *Model) newConversation() { - m.conversation = []llm.Input{} + m.conversation = []llm.Message{} m.messages = []DisplayMessage{} m.pendingImages = []llm.Image{} m.err = nil } -// addUserMessage adds a user message to the conversation +// addUserMessage adds a user message to the conversation. func (m *Model) addUserMessage(text string, images []llm.Image) { msg := llm.Message{ - Role: llm.RoleUser, - Text: text, - Images: images, + Role: llm.RoleUser, + Content: llm.Content{Text: text, Images: images}, } m.conversation = append(m.conversation, msg) m.messages = append(m.messages, DisplayMessage{ @@ -270,7 +220,7 @@ func (m *Model) addUserMessage(text string, images []llm.Image) { }) } -// addAssistantMessage adds an assistant message to the conversation +// addAssistantMessage adds an assistant message to the conversation display. func (m *Model) addAssistantMessage(content string) { m.messages = append(m.messages, DisplayMessage{ Role: llm.RoleAssistant, @@ -278,7 +228,7 @@ func (m *Model) addAssistantMessage(content string) { }) } -// addToolCallMessage adds a tool call message to display +// addToolCallMessage adds a tool call message to display. func (m *Model) addToolCallMessage(name string, args string) { m.messages = append(m.messages, DisplayMessage{ Role: llm.Role("tool_call"), @@ -286,7 +236,7 @@ func (m *Model) addToolCallMessage(name string, args string) { }) } -// addToolResultMessage adds a tool result message to display +// addToolResultMessage adds a tool result message to display. 
func (m *Model) addToolResultMessage(name string, result string) { m.messages = append(m.messages, DisplayMessage{ Role: llm.Role("tool_result"), diff --git a/cmd/llm/styles.go b/v2/cmd/llm/styles.go similarity index 100% rename from cmd/llm/styles.go rename to v2/cmd/llm/styles.go diff --git a/v2/cmd/llm/tools.go b/v2/cmd/llm/tools.go new file mode 100644 index 0000000..5f9bead --- /dev/null +++ b/v2/cmd/llm/tools.go @@ -0,0 +1,114 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// TimeParams is the parameter struct for the GetTime function. +type TimeParams struct{} + +// GetTime returns the current time. +func GetTime(_ context.Context, _ TimeParams) (string, error) { + return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil +} + +// CalcParams is the parameter struct for the Calculate function. +type CalcParams struct { + A float64 `json:"a" description:"First number"` + B float64 `json:"b" description:"Second number"` + Op string `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"` +} + +// Calculate performs basic math operations. 
+func Calculate(_ context.Context, params CalcParams) (string, error) { + var result float64 + switch strings.ToLower(params.Op) { + case "add", "+": + result = params.A + params.B + case "subtract", "sub", "-": + result = params.A - params.B + case "multiply", "mul", "*": + result = params.A * params.B + case "divide", "div", "/": + if params.B == 0 { + return "", fmt.Errorf("division by zero") + } + result = params.A / params.B + case "power", "pow", "^": + result = math.Pow(params.A, params.B) + case "sqrt": + if params.A < 0 { + return "", fmt.Errorf("cannot take square root of negative number") + } + result = math.Sqrt(params.A) + case "mod", "%": + result = math.Mod(params.A, params.B) + default: + return "", fmt.Errorf("unknown operation: %s", params.Op) + } + return strconv.FormatFloat(result, 'f', -1, 64), nil +} + +// WeatherParams is the parameter struct for the GetWeather function. +type WeatherParams struct { + Location string `json:"location" description:"City name or location"` +} + +// GetWeather returns mock weather data (for demo purposes). +func GetWeather(_ context.Context, params WeatherParams) (string, error) { + weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"} + temps := []int{65, 72, 58, 80, 45} + idx := len(params.Location) % len(weathers) + + out := map[string]any{ + "location": params.Location, + "temperature": strconv.Itoa(temps[idx]) + "F", + "condition": weathers[idx], + "humidity": "45%", + "note": "This is mock data for demonstration purposes", + } + b, err := json.Marshal(out) + if err != nil { + return "", err + } + return string(b), nil +} + +// RandomNumberParams is the parameter struct for the RandomNumber function. +type RandomNumberParams struct { + Min int `json:"min" description:"Minimum value (inclusive)"` + Max int `json:"max" description:"Maximum value (inclusive)"` +} + +// RandomNumber generates a pseudo-random number (using current time nanoseconds). 
+func RandomNumber(_ context.Context, params RandomNumberParams) (string, error) { + if params.Min > params.Max { + return "", fmt.Errorf("min cannot be greater than max") + } + n := time.Now().UnixNano() + rangeSize := params.Max - params.Min + 1 + result := params.Min + int(n%int64(rangeSize)) + return strconv.Itoa(result), nil +} + +// createDemoToolbox creates a toolbox with demo tools for testing. +func createDemoToolbox() *llm.ToolBox { + return llm.NewToolBox( + llm.Define[TimeParams]("get_time", "Get the current date and time", GetTime), + llm.Define[CalcParams]("calculate", + "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", + Calculate), + llm.Define[WeatherParams]("get_weather", + "Get weather information for a location (demo data)", GetWeather), + llm.Define[RandomNumberParams]("random_number", + "Generate a random number between min and max", RandomNumber), + ) +} diff --git a/cmd/llm/update.go b/v2/cmd/llm/update.go similarity index 68% rename from cmd/llm/update.go rename to v2/cmd/llm/update.go index 4592683..54391b9 100644 --- a/cmd/llm/update.go +++ b/v2/cmd/llm/update.go @@ -8,14 +8,14 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" - llm "gitea.stevedudenhoeffer.com/steve/go-llm" + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" ) -// pendingRequest stores the request being processed for follow-up -var pendingRequest llm.Request -var pendingResponse llm.ResponseChoice +// pendingToolCalls stores the last response's tool calls so we can pair them +// with tool execution results for display. +var pendingToolCalls []llm.ToolCall -// Update handles messages and updates the model +// Update handles messages and updates the model. 
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd var cmds []tea.Cmd @@ -53,40 +53,30 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } - if len(msg.Response.Choices) == 0 { - m.err = fmt.Errorf("no response choices returned") - return m, nil + resp := msg.Response + + // Add the assistant message to the conversation history. + m.conversation = append(m.conversation, resp.Message()) + + // Show any text the assistant produced alongside tool calls. + if resp.Text != "" { + m.addAssistantMessage(resp.Text) } - choice := msg.Response.Choices[0] + if resp.HasToolCalls() && m.toolsEnabled { + pendingToolCalls = resp.ToolCalls - // Check for tool calls - if len(choice.Calls) > 0 && m.toolsEnabled { - // Store for follow-up - pendingResponse = choice - - // Add assistant's response to conversation if there's content - if choice.Content != "" { - m.addAssistantMessage(choice.Content) - } - - // Display tool calls - for _, call := range choice.Calls { - m.addToolCallMessage(call.FunctionCall.Name, call.FunctionCall.Arguments) + for _, call := range resp.ToolCalls { + m.addToolCallMessage(call.Name, call.Arguments) } m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() - // Execute tools m.loading = true - return m, executeTools(m.toolbox, pendingRequest, choice) + return m, executeTools(m.toolbox, resp.ToolCalls) } - // Regular response - add to conversation and display - m.conversation = append(m.conversation, choice) - m.addAssistantMessage(choice.Content) - m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() @@ -97,31 +87,24 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } - // Display tool results + // Display results paired with the tool calls that produced them. 
for i, result := range msg.Results { - name := pendingResponse.Calls[i].FunctionCall.Name - resultStr := fmt.Sprintf("%v", result.Result) - if result.Error != nil { - resultStr = "Error: " + result.Error.Error() + name := "" + if i < len(pendingToolCalls) { + name = pendingToolCalls[i].Name } - m.addToolResultMessage(name, resultStr) + m.addToolResultMessage(name, result.Content.Text) } - // Add tool call responses to conversation - for _, result := range msg.Results { - m.conversation = append(m.conversation, result) - } - - // Add the assistant's response to conversation - m.conversation = append(m.conversation, pendingResponse) + // Append the raw tool result messages to the conversation so the + // assistant can reference them on the next turn. + m.conversation = append(m.conversation, msg.Results...) m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() - // Send follow-up request - followUp := buildFollowUpRequest(&m, pendingRequest, pendingResponse, msg.Results) - pendingRequest = followUp - return m, sendChatRequest(m.chat, followUp) + // Ask the model to continue given the tool results. + return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature) case ImageLoadedMsg: if msg.Err != nil { @@ -135,7 +118,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.err = nil default: - // Update text input + // Update text input. if m.state == StateChat { m.input, cmd = m.input.Update(msg) cmds = append(cmds, cmd) @@ -148,13 +131,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } -// handleKeyMsg handles keyboard input +// handleKeyMsg handles keyboard input. 
func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - // Global key handling switch msg.String() { case "ctrl+c": return m, tea.Quit - case "esc": if m.state != StateChat { m.state = StateChat @@ -164,7 +145,6 @@ func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, tea.Quit } - // State-specific key handling switch m.state { case StateChat: return m.handleChatKeys(msg) @@ -185,7 +165,7 @@ func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleChatKeys handles keys in chat state +// handleChatKeys handles keys in chat state. func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "enter": @@ -203,14 +183,13 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } - // Build and send request - req := buildRequest(&m, text) - pendingRequest = req + // Ensure a system message is at the head of the conversation. + if len(m.conversation) == 0 && m.systemPrompt != "" { + m.conversation = append(m.conversation, llm.SystemMessage(m.systemPrompt)) + } - // Add user message to display m.addUserMessage(text, m.pendingImages) - // Clear input and pending images m.input.Reset() m.pendingImages = nil m.err = nil @@ -219,7 +198,7 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() - return m, sendChatRequest(m.chat, req) + return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature) case "ctrl+i": m.previousState = StateChat @@ -238,12 +217,12 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil case "ctrl+m": - if m.provider == nil { + if m.client == nil { m.err = fmt.Errorf("select a provider first") return m, nil } m.state = StateModelSelect - m.listItems = m.providers[m.providerIndex].Models + m.listItems = m.providers[m.providerIndex].Info.Models m.listIndex = 
m.providers[m.providerIndex].ModelIndex return m, nil @@ -268,7 +247,7 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } -// handleProviderSelectKeys handles keys in provider selection state +// handleProviderSelectKeys handles keys in provider selection state. func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "up", "k": @@ -282,15 +261,13 @@ func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { case "enter": p := m.providers[m.listIndex] if !p.HasAPIKey { - // Need to get API key m.state = StateAPIKeyInput m.apiKeyInput.Focus() m.apiKeyInput.SetValue("") return m, textinput.Blink } - err := m.selectProvider(m.listIndex) - if err != nil { + if err := m.selectProvider(m.listIndex); err != nil { m.err = err return m, nil } @@ -303,7 +280,7 @@ func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleAPIKeyInputKeys handles keys in API key input state +// handleAPIKeyInputKeys handles keys in API key input state. 
func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "enter": @@ -312,23 +289,22 @@ func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } - // Store the API key p := m.providers[m.listIndex] - m.apiKeys[p.Name] = key + m.apiKeys[p.Info.Name] = key m.providers[m.listIndex].HasAPIKey = true - // Update list items for i, prov := range m.providers { status := " (no key)" if prov.HasAPIKey { status = " (ready)" + if prov.Info.EnvKey == "" { + status = " (local)" + } } - m.listItems[i] = prov.Name + status + m.listItems[i] = prov.Info.DisplayName + status } - // Select the provider - err := m.selectProvider(m.listIndex) - if err != nil { + if err := m.selectProvider(m.listIndex); err != nil { m.err = err return m, nil } @@ -345,7 +321,7 @@ func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } -// handleModelSelectKeys handles keys in model selection state +// handleModelSelectKeys handles keys in model selection state. func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "up", "k": @@ -357,8 +333,7 @@ func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.listIndex++ } case "enter": - err := m.selectModel(m.listIndex) - if err != nil { + if err := m.selectModel(m.listIndex); err != nil { m.err = err return m, nil } @@ -368,7 +343,7 @@ func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleImageInputKeys handles keys in image input state +// handleImageInputKeys handles keys in image input state. func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "enter": @@ -381,12 +356,12 @@ func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.input.Placeholder = "Type your message..." 
- // Determine input type and load - if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { + switch { + case strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://"): return m, loadImageFromURL(input) - } else if strings.HasPrefix(input, "data:") || len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\") { + case strings.HasPrefix(input, "data:") || (len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\")): return m, loadImageFromBase64(input) - } else { + default: return m, loadImageFromPath(input) } @@ -397,7 +372,7 @@ func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } -// handleToolsPanelKeys handles keys in tools panel state +// handleToolsPanelKeys handles keys in tools panel state. func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "t": @@ -409,11 +384,10 @@ func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleSettingsKeys handles keys in settings state +// handleSettingsKeys handles keys in settings state. func (m Model) handleSettingsKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "1": - // Set temperature to nil (default) m.temperature = nil case "2": t := 0.0 diff --git a/cmd/llm/view.go b/v2/cmd/llm/view.go similarity index 88% rename from cmd/llm/view.go rename to v2/cmd/llm/view.go index bc63e0e..66bc449 100644 --- a/cmd/llm/view.go +++ b/v2/cmd/llm/view.go @@ -6,10 +6,10 @@ import ( "github.com/charmbracelet/lipgloss" - llm "gitea.stevedudenhoeffer.com/steve/go-llm" + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" ) -// View renders the current state +// View renders the current state. func (m Model) View() string { switch m.state { case StateProviderSelect: @@ -29,11 +29,10 @@ func (m Model) View() string { } } -// renderChat renders the main chat view +// renderChat renders the main chat view. 
func (m Model) renderChat() string { var b strings.Builder - // Header provider := m.providerName if provider == "" { provider = "None" @@ -49,43 +48,37 @@ func (m Model) renderChat() string { b.WriteString(header) b.WriteString("\n") - // Messages viewport if m.viewportReady { b.WriteString(m.viewport.View()) b.WriteString("\n") } - // Image indicator if len(m.pendingImages) > 0 { b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s) attached]", len(m.pendingImages)))) b.WriteString("\n") } - // Error if m.err != nil { b.WriteString(errorStyle.Render(" Error: " + m.err.Error())) b.WriteString("\n") } - // Loading if m.loading { b.WriteString(loadingStyle.Render(" Thinking...")) b.WriteString("\n") } - // Input inputBox := inputStyle.Render(m.input.View()) b.WriteString(inputBox) b.WriteString("\n") - // Help help := inputHelpStyle.Render("Enter: send | Ctrl+I: image | Ctrl+T: tools | Ctrl+P: provider | Ctrl+M: model | Ctrl+S: settings | Ctrl+N: new | Esc: quit") b.WriteString(help) return appStyle.Render(b.String()) } -// renderMessages renders all messages for the viewport +// renderMessages renders all messages for the viewport. func (m Model) renderMessages() string { var b strings.Builder @@ -133,7 +126,7 @@ func (m Model) renderMessages() string { return b.String() } -// renderProviderSelect renders the provider selection view +// renderProviderSelect renders the provider selection view. func (m Model) renderProviderSelect() string { var b strings.Builder @@ -157,16 +150,18 @@ func (m Model) renderProviderSelect() string { return appStyle.Render(b.String()) } -// renderAPIKeyInput renders the API key input view +// renderAPIKeyInput renders the API key input view. 
func (m Model) renderAPIKeyInput() string { var b strings.Builder provider := m.providers[m.listIndex] - b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Name))) + b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Info.DisplayName))) b.WriteString("\n\n") - b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.EnvVar)) + if provider.Info.EnvKey != "" { + b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.Info.EnvKey)) + } b.WriteString("Enter your API key below (it will be hidden):\n\n") inputBox := inputStyle.Render(m.apiKeyInput.View()) @@ -178,7 +173,7 @@ func (m Model) renderAPIKeyInput() string { return appStyle.Render(b.String()) } -// renderModelSelect renders the model selection view +// renderModelSelect renders the model selection view. func (m Model) renderModelSelect() string { var b strings.Builder @@ -205,7 +200,7 @@ func (m Model) renderModelSelect() string { return appStyle.Render(b.String()) } -// renderImageInput renders the image input view +// renderImageInput renders the image input view. func (m Model) renderImageInput() string { var b strings.Builder @@ -230,7 +225,7 @@ func (m Model) renderImageInput() string { return appStyle.Render(b.String()) } -// renderToolsPanel renders the tools panel +// renderToolsPanel renders the tools panel. 
func (m Model) renderToolsPanel() string { var b strings.Builder @@ -249,8 +244,10 @@ func (m Model) renderToolsPanel() string { b.WriteString("\n\n") b.WriteString("Available tools:\n") - for _, fn := range m.toolbox.Functions() { - b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(fn.Name), fn.Description)) + if m.toolbox != nil { + for _, t := range m.toolbox.AllTools() { + b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(t.Name), t.Description)) + } } b.WriteString("\n") @@ -259,14 +256,13 @@ func (m Model) renderToolsPanel() string { return appStyle.Render(b.String()) } -// renderSettings renders the settings view +// renderSettings renders the settings view. func (m Model) renderSettings() string { var b strings.Builder b.WriteString(headerStyle.Render("Settings")) b.WriteString("\n\n") - // Temperature tempStr := "default" if m.temperature != nil { tempStr = fmt.Sprintf("%.1f", *m.temperature) @@ -284,7 +280,6 @@ func (m Model) renderSettings() string { b.WriteString("\n") - // System prompt b.WriteString(settingLabelStyle.Render("System Prompt:")) b.WriteString("\n") b.WriteString(settingValueStyle.Render(" " + m.systemPrompt)) diff --git a/v2/constructors.go b/v2/constructors.go index 9fa7d11..846c697 100644 --- a/v2/constructors.go +++ b/v2/constructors.go @@ -2,8 +2,13 @@ package llm import ( anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic" + deepseekProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek" googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google" + groqProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq" + moonshotProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot" + ollamaProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama" openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai" + xaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai" ) // OpenAI creates an OpenAI client. 
@@ -46,3 +51,69 @@ func Google(apiKey string, opts ...ClientOption) *Client { _ = cfg // Google doesn't support custom base URL in the SDK return NewClient(googleProvider.New(apiKey)) } + +// DeepSeek creates a DeepSeek client (OpenAI-compatible). +// +// Example: +// +// model := llm.DeepSeek("sk-...").Model("deepseek-chat") +func DeepSeek(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(deepseekProvider.New(apiKey, cfg.baseURL)) +} + +// Moonshot creates a Moonshot AI (Kimi) client (OpenAI-compatible). +// +// Example: +// +// model := llm.Moonshot("sk-...").Model("kimi-k2-0711-preview") +func Moonshot(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(moonshotProvider.New(apiKey, cfg.baseURL)) +} + +// XAI creates an xAI (Grok) client (OpenAI-compatible). +// +// Example: +// +// model := llm.XAI("xai-...").Model("grok-2") +func XAI(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(xaiProvider.New(apiKey, cfg.baseURL)) +} + +// Groq creates a Groq client (OpenAI-compatible). +// +// Example: +// +// model := llm.Groq("gsk-...").Model("llama-3.3-70b-versatile") +func Groq(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(groqProvider.New(apiKey, cfg.baseURL)) +} + +// Ollama creates a client for a local Ollama instance (OpenAI-compatible). +// No API key is required. Use WithBaseURL to point at a non-default host/port. 
+// +// Example: +// +// model := llm.Ollama().Model("llama3.2") +func Ollama(opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(ollamaProvider.New("", cfg.baseURL)) +} diff --git a/v2/deepseek/deepseek.go b/v2/deepseek/deepseek.go new file mode 100644 index 0000000..be67913 --- /dev/null +++ b/v2/deepseek/deepseek.go @@ -0,0 +1,36 @@ +// Package deepseek implements the go-llm v2 provider interface for DeepSeek +// (https://platform.deepseek.com). DeepSeek speaks the OpenAI Chat Completions +// protocol, so this package is a thin wrapper around openaicompat with its own +// defaults and per-model Rules. +package deepseek + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public DeepSeek API endpoint. +const DefaultBaseURL = "https://api.deepseek.com/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new DeepSeek provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // DeepSeek's chat and reasoner models are text-only. + SupportsVision: func(string) bool { return false }, + // Reasoner doesn't accept tool calls. + SupportsTools: func(m string) bool { + return !strings.Contains(m, "reasoner") + }, + // Reasoner rejects user-supplied temperature. 
+ RestrictTemperature: func(m string) bool { + return strings.Contains(m, "reasoner") + }, + }) +} diff --git a/v2/deepseek/deepseek_test.go b/v2/deepseek/deepseek_test.go new file mode 100644 index 0000000..0d9b629 --- /dev/null +++ b/v2/deepseek/deepseek_test.go @@ -0,0 +1,49 @@ +package deepseek_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestNew_DefaultBaseURL(t *testing.T) { + if p := deepseek.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_ReasonerRejectsTools(t *testing.T) { + p := deepseek.New("key", "") + req := provider.Request{ + Model: "deepseek-reasoner", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + Tools: []provider.ToolDef{ + {Name: "x", Schema: map[string]any{"type": "object"}}, + }, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "tools" { + t.Fatalf("want FeatureUnsupportedError(tools), got %v", err) + } +} + +func TestRules_ChatRejectsImages(t *testing.T) { + p := deepseek.New("key", "") + req := provider.Request{ + Model: "deepseek-chat", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{{URL: "a"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "vision" { + t.Fatalf("want FeatureUnsupportedError(vision), got %v", err) + } +} diff --git a/v2/go.mod b/v2/go.mod index d57588e..cedc62d 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -1,10 +1,12 @@ module gitea.stevedudenhoeffer.com/steve/go-llm/v2 -go 1.24.0 - -toolchain go1.24.2 +go 1.24.2 require ( + github.com/charmbracelet/bubbles v1.0.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 + 
github.com/joho/godotenv v1.5.1 github.com/liushuangls/go-anthropic/v2 v2.17.0 github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/openai/openai-go v1.12.0 @@ -18,6 +20,16 @@ require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.9.3 // indirect cloud.google.com/go/compute/metadata v0.5.0 // indirect + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/charmbracelet/colorprofile v0.4.1 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.9.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.5.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/jsonschema-go v0.3.0 // indirect @@ -25,15 +37,24 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/kr/fs v0.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect 
go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.42.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/grpc v1.66.2 // indirect diff --git a/v2/go.sum b/v2/go.sum index bd64989..5ef6ced 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -6,8 +6,32 @@ cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842Bg cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= +github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= 
+github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= +github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -16,6 +40,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod 
h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -49,12 +75,28 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c= github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod 
h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= @@ -62,6 +104,8 @@ github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1Hbe github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -80,6 +124,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= @@ -89,6 +135,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -114,8 +162,10 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys 
v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/v2/groq/groq.go b/v2/groq/groq.go new file mode 100644 index 0000000..cd1d9a7 --- /dev/null +++ b/v2/groq/groq.go @@ -0,0 +1,33 @@ +// Package groq implements the go-llm v2 provider interface for Groq +// (https://console.groq.com). Groq hosts open-source models behind an OpenAI +// Chat Completions-compatible endpoint, so this package is a thin wrapper over +// openaicompat with its own defaults and per-model Rules. +package groq + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public Groq OpenAI-compatible endpoint. +const DefaultBaseURL = "https://api.groq.com/openai/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new Groq provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // Only Groq-hosted vision variants (e.g. 
*-vision-preview) accept images. + SupportsVision: func(m string) bool { + return strings.Contains(m, "vision") + }, + // Chat completions endpoint does not accept audio input; audio is via + // dedicated transcription endpoints, which go-llm doesn't cover here. + SupportsAudio: func(string) bool { return false }, + }) +} diff --git a/v2/groq/groq_test.go b/v2/groq/groq_test.go new file mode 100644 index 0000000..d50c118 --- /dev/null +++ b/v2/groq/groq_test.go @@ -0,0 +1,33 @@ +package groq_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestNew_Basic(t *testing.T) { + if p := groq.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_AudioRejected(t *testing.T) { + p := groq.New("key", "") + req := provider.Request{ + Model: "llama-3.3-70b-versatile", + Messages: []provider.Message{{ + Role: "user", + Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "audio" { + t.Fatalf("want FeatureUnsupportedError(audio), got %v", err) + } +} diff --git a/v2/moonshot/moonshot.go b/v2/moonshot/moonshot.go new file mode 100644 index 0000000..8d7fbe1 --- /dev/null +++ b/v2/moonshot/moonshot.go @@ -0,0 +1,30 @@ +// Package moonshot implements the go-llm v2 provider interface for Moonshot +// AI (Kimi, https://platform.moonshot.ai). Moonshot speaks OpenAI Chat +// Completions, so this package is a thin wrapper over openaicompat with its +// own defaults and per-model Rules. +package moonshot + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public Moonshot API endpoint (international). 
+const DefaultBaseURL = "https://api.moonshot.ai/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new Moonshot provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // Only Moonshot models whose name contains "vision" accept images. + SupportsVision: func(m string) bool { + return strings.Contains(m, "vision") + }, + }) +} diff --git a/v2/moonshot/moonshot_test.go b/v2/moonshot/moonshot_test.go new file mode 100644 index 0000000..d3f3c6f --- /dev/null +++ b/v2/moonshot/moonshot_test.go @@ -0,0 +1,33 @@ +package moonshot_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestNew_Basic(t *testing.T) { + if p := moonshot.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_NonVisionModelRejectsImages(t *testing.T) { + p := moonshot.New("key", "") + req := provider.Request{ + Model: "moonshot-v1-8k", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{{URL: "a"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "vision" { + t.Fatalf("want FeatureUnsupportedError(vision), got %v", err) + } +} diff --git a/v2/ollama/ollama.go b/v2/ollama/ollama.go new file mode 100644 index 0000000..75623ca --- /dev/null +++ b/v2/ollama/ollama.go @@ -0,0 +1,25 @@ +// Package ollama implements the go-llm v2 provider interface for Ollama +// (https://ollama.com), a local model runner that exposes an OpenAI Chat +// Completions-compatible endpoint. 
No API key is required; capability depends +// on whichever model the user has pulled locally, so Rules are intentionally +// empty — we trust the local user. +package ollama + +import ( + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL points at a local Ollama instance with default port. +const DefaultBaseURL = "http://localhost:11434/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new Ollama provider. An empty baseURL uses DefaultBaseURL. +// Ollama ignores the API key; callers may pass "". +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{}) +} diff --git a/v2/ollama/ollama_test.go b/v2/ollama/ollama_test.go new file mode 100644 index 0000000..5c1e216 --- /dev/null +++ b/v2/ollama/ollama_test.go @@ -0,0 +1,13 @@ +package ollama_test + +import ( + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama" +) + +func TestNew_NoKeyNeeded(t *testing.T) { + if p := ollama.New("", ""); p == nil { + t.Fatal("New returned nil") + } +} diff --git a/v2/openai/openai.go b/v2/openai/openai.go index 4193542..5b40964 100644 --- a/v2/openai/openai.go +++ b/v2/openai/openai.go @@ -1,433 +1,35 @@ // Package openai implements the go-llm v2 provider interface for OpenAI. +// +// The actual wire-protocol logic lives in the shared openaicompat package; +// this file encodes OpenAI-specific Rules (temperature is rejected on o-series +// and gpt-5* models) and supplies the default base URL. 
package openai import ( - "context" - "encoding/base64" - "fmt" - "io" - "net/http" - "path" "strings" - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/shared" - - "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" ) -// Provider implements the provider.Provider interface for OpenAI. -type Provider struct { - apiKey string - baseURL string -} +// DefaultBaseURL is the public OpenAI Chat Completions endpoint. +const DefaultBaseURL = "https://api.openai.com/v1" -// New creates a new OpenAI provider. +// Provider is the OpenAI chat-completion provider. It's a type alias over +// openaicompat.Provider so existing callers using openai.Provider keep compiling. +type Provider = openaicompat.Provider + +// New creates a new OpenAI provider. An empty baseURL uses DefaultBaseURL. func New(apiKey string, baseURL string) *Provider { - return &Provider{apiKey: apiKey, baseURL: baseURL} + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + RestrictTemperature: restrictTemperature, + }) } -// Complete performs a non-streaming completion. -func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(p.apiKey)) - if p.baseURL != "" { - opts = append(opts, option.WithBaseURL(p.baseURL)) - } - - cl := openai.NewClient(opts...) - oaiReq := p.buildRequest(req) - - resp, err := cl.Chat.Completions.New(ctx, oaiReq) - if err != nil { - return provider.Response{}, fmt.Errorf("openai completion error: %w", err) - } - - return p.convertResponse(resp), nil -} - -// Stream performs a streaming completion. 
-func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(p.apiKey)) - if p.baseURL != "" { - opts = append(opts, option.WithBaseURL(p.baseURL)) - } - - cl := openai.NewClient(opts...) - oaiReq := p.buildRequest(req) - oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - } - - stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) - - var fullText strings.Builder - var toolCalls []provider.ToolCall - toolCallArgs := map[int]*strings.Builder{} - var usage *provider.Usage - - for stream.Next() { - chunk := stream.Current() - - // Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true) - if chunk.Usage.TotalTokens > 0 { - usage = &provider.Usage{ - InputTokens: int(chunk.Usage.PromptTokens), - OutputTokens: int(chunk.Usage.CompletionTokens), - TotalTokens: int(chunk.Usage.TotalTokens), - Details: extractUsageDetails(chunk.Usage), - } - } - - for _, choice := range chunk.Choices { - // Text delta - if choice.Delta.Content != "" { - fullText.WriteString(choice.Delta.Content) - events <- provider.StreamEvent{ - Type: provider.StreamEventText, - Text: choice.Delta.Content, - } - } - - // Tool call deltas - for _, tc := range choice.Delta.ToolCalls { - idx := int(tc.Index) - - if tc.ID != "" { - // New tool call starting - for len(toolCalls) <= idx { - toolCalls = append(toolCalls, provider.ToolCall{}) - } - toolCalls[idx].ID = tc.ID - toolCalls[idx].Name = tc.Function.Name - toolCallArgs[idx] = &strings.Builder{} - - events <- provider.StreamEvent{ - Type: provider.StreamEventToolStart, - ToolCall: &provider.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - }, - ToolIndex: idx, - } - } - - if tc.Function.Arguments != "" { - if b, ok := toolCallArgs[idx]; ok { - b.WriteString(tc.Function.Arguments) - } - events <- provider.StreamEvent{ - Type: 
provider.StreamEventToolDelta, - ToolIndex: idx, - ToolCall: &provider.ToolCall{ - Arguments: tc.Function.Arguments, - }, - } - } - } - } - } - - if err := stream.Err(); err != nil { - return fmt.Errorf("openai stream error: %w", err) - } - - // Finalize tool calls - for idx := range toolCalls { - if b, ok := toolCallArgs[idx]; ok { - toolCalls[idx].Arguments = b.String() - } - events <- provider.StreamEvent{ - Type: provider.StreamEventToolEnd, - ToolIndex: idx, - ToolCall: &toolCalls[idx], - } - } - - // Send done event - events <- provider.StreamEvent{ - Type: provider.StreamEventDone, - Response: &provider.Response{ - Text: fullText.String(), - ToolCalls: toolCalls, - Usage: usage, - }, - } - - return nil -} - -func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { - oaiReq := openai.ChatCompletionNewParams{ - Model: req.Model, - } - - for _, msg := range req.Messages { - oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) - } - - for _, tool := range req.Tools { - oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ - Type: "function", - Function: shared.FunctionDefinitionParam{ - Name: tool.Name, - Description: openai.String(tool.Description), - Parameters: openai.FunctionParameters(tool.Schema), - }, - }) - } - - if req.Temperature != nil { - // o* and gpt-5* models don't support custom temperatures - if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") { - oaiReq.Temperature = openai.Float(*req.Temperature) - } - } - - if req.MaxTokens != nil { - oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) - } - - if req.TopP != nil { - oaiReq.TopP = openai.Float(*req.TopP) - } - - if len(req.Stop) > 0 { - oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} - } - - return oaiReq -} - -func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { - var arrayOfContentParts 
[]openai.ChatCompletionContentPartUnionParam - var textContent param.Opt[string] - - for _, img := range msg.Images { - var url string - if img.Base64 != "" { - url = "data:" + img.ContentType + ";base64," + img.Base64 - } else if img.URL != "" { - url = img.URL - } - if url != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ - URL: url, - }, - }, - }, - ) - } - } - - for _, aud := range msg.Audio { - var b64Data string - var format string - - if aud.Base64 != "" { - b64Data = aud.Base64 - format = audioFormat(aud.ContentType) - } else if aud.URL != "" { - resp, err := http.Get(aud.URL) - if err != nil { - continue - } - data, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - continue - } - b64Data = base64.StdEncoding.EncodeToString(data) - ct := resp.Header.Get("Content-Type") - if ct == "" { - ct = aud.ContentType - } - if ct == "" { - ct = audioFormatFromURL(aud.URL) - } - format = audioFormat(ct) - } - - if b64Data != "" && format != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ - InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ - Data: b64Data, - Format: format, - }, - }, - }, - ) - } - } - - if msg.Content != "" { - if len(arrayOfContentParts) > 0 { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{ - Text: msg.Content, - }, - }, - ) - } else { - textContent = openai.String(msg.Content) - } - } - - // Determine if this model uses developer messages instead of system - useDeveloper := false - parts := strings.Split(model, "-") - if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { - useDeveloper = true - } - - 
switch msg.Role { - case "system": - if useDeveloper { - return openai.ChatCompletionMessageParamUnion{ - OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } - return openai.ChatCompletionMessageParamUnion{ - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - - case "user": - return openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfString: textContent, - OfArrayOfContentParts: arrayOfContentParts, - }, - }, - } - - case "assistant": - as := &openai.ChatCompletionAssistantMessageParam{} - if msg.Content != "" { - as.Content.OfString = openai.String(msg.Content) - } - for _, tc := range msg.ToolCalls { - as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ - ID: tc.ID, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: tc.Name, - Arguments: tc.Arguments, - }, - }) - } - return openai.ChatCompletionMessageParamUnion{OfAssistant: as} - - case "tool": - return openai.ChatCompletionMessageParamUnion{ - OfTool: &openai.ChatCompletionToolMessageParam{ - ToolCallID: msg.ToolCallID, - Content: openai.ChatCompletionToolMessageParamContentUnion{ - OfString: openai.String(msg.Content), - }, - }, - } - } - - // Fallback to user message - return openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfString: textContent, - }, - }, - } -} - -func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response { - var res provider.Response - - if resp == nil || len(resp.Choices) == 0 { - return res - } - - choice := resp.Choices[0] - res.Text = choice.Message.Content - - for _, tc := range 
choice.Message.ToolCalls { - res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: strings.TrimSpace(tc.Function.Arguments), - }) - } - - if resp.Usage.TotalTokens > 0 { - res.Usage = &provider.Usage{ - InputTokens: int(resp.Usage.PromptTokens), - OutputTokens: int(resp.Usage.CompletionTokens), - TotalTokens: int(resp.Usage.TotalTokens), - } - res.Usage.Details = extractUsageDetails(resp.Usage) - } - - return res -} - -// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3"). -func audioFormat(contentType string) string { - ct := strings.ToLower(contentType) - switch { - case strings.Contains(ct, "wav"): - return "wav" - case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"): - return "mp3" - default: - return "wav" - } -} - -// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage. -func extractUsageDetails(usage openai.CompletionUsage) map[string]int { - details := map[string]int{} - if usage.CompletionTokensDetails.ReasoningTokens > 0 { - details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens) - } - if usage.CompletionTokensDetails.AudioTokens > 0 { - details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens) - } - if usage.PromptTokensDetails.CachedTokens > 0 { - details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens) - } - if usage.PromptTokensDetails.AudioTokens > 0 { - details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens) - } - if len(details) == 0 { - return nil - } - return details -} - -// audioFormatFromURL guesses the audio format from a URL's file extension. 
-func audioFormatFromURL(u string) string { - ext := strings.ToLower(path.Ext(u)) - switch ext { - case ".mp3": - return "audio/mp3" - case ".wav": - return "audio/wav" - default: - return "audio/wav" - } +// restrictTemperature reports whether OpenAI rejects a user-supplied +// temperature for this model. o-series reasoning models and gpt-5* both do. +func restrictTemperature(model string) bool { + return strings.HasPrefix(model, "o") || strings.HasPrefix(model, "gpt-5") } diff --git a/v2/openaicompat/openaicompat.go b/v2/openaicompat/openaicompat.go new file mode 100644 index 0000000..690b358 --- /dev/null +++ b/v2/openaicompat/openaicompat.go @@ -0,0 +1,537 @@ +// Package openaicompat implements a shared chat-completion Provider for any +// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek, +// Moonshot, xAI, Groq, Ollama, and friends). +// +// Most providers differ from vanilla OpenAI only in endpoint URL and a handful +// of per-model quirks (e.g., "this model is text-only", "this model doesn't +// accept tools", "drop temperature on reasoning models"). Those quirks are +// captured declaratively via Rules, so a concrete provider package is usually +// a one-function wrapper that calls New with its own base URL and Rules. +package openaicompat + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "path" + "strings" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/shared" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// Rules encodes provider-specific constraints on top of the OpenAI wire +// protocol. The zero value means "no restrictions" and behaves like vanilla +// OpenAI. Individual fields are documented inline. +type Rules struct { + // MaxImagesPerMessage rejects requests whose any single message carries + // more images than this cap. 0 means "no cap". 
+ MaxImagesPerMessage int + + // MaxAudioPerMessage rejects requests whose any single message carries + // more audio attachments than this cap. 0 means "no cap". + MaxAudioPerMessage int + + // SupportsVision, when non-nil, is consulted for every request that + // includes any image attachments. If it returns false for the request's + // model, the call fails with a FeatureUnsupportedError before hitting + // the network. + SupportsVision func(model string) bool + + // SupportsTools, when non-nil, is consulted for every request that + // includes any tool definitions. If it returns false for the model, + // the call fails with a FeatureUnsupportedError before hitting the + // network. + SupportsTools func(model string) bool + + // SupportsAudio, when non-nil, is consulted for every request that + // includes any audio attachments. If it returns false for the model, + // the call fails with a FeatureUnsupportedError. + SupportsAudio func(model string) bool + + // RestrictTemperature, when non-nil and returning true for the request's + // model, causes the Temperature field to be silently dropped from the + // outgoing request. Used by OpenAI o-series and gpt-5* which reject a + // user-provided temperature. + RestrictTemperature func(model string) bool + + // CustomizeRequest is a last-mile hook invoked after buildRequest but + // before the call is sent. It receives the fully built OpenAI SDK + // parameters and may mutate them freely (add headers, flip flags, tweak + // response_format, etc.). + CustomizeRequest func(params *openai.ChatCompletionNewParams) +} + +// FeatureUnsupportedError is returned when a Rules predicate rejects a request +// because the target model does not support a feature the caller included. 
+type FeatureUnsupportedError struct { + Feature string + Model string +} + +func (e *FeatureUnsupportedError) Error() string { + return fmt.Sprintf("openaicompat: model %q does not support %s", e.Model, e.Feature) +} + +// Provider implements provider.Provider for any OpenAI-compatible endpoint. +type Provider struct { + apiKey string + baseURL string + rules Rules +} + +// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its +// default; in practice concrete provider packages always pass a default. +func New(apiKey, baseURL string, rules Rules) *Provider { + return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules} +} + +// Complete performs a non-streaming completion. +func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + if err := p.checkRules(req); err != nil { + return provider.Response{}, err + } + + cl := openai.NewClient(p.requestOptions()...) + oaiReq := p.buildRequest(req) + if p.rules.CustomizeRequest != nil { + p.rules.CustomizeRequest(&oaiReq) + } + + resp, err := cl.Chat.Completions.New(ctx, oaiReq) + if err != nil { + return provider.Response{}, fmt.Errorf("openai completion error: %w", err) + } + + return p.convertResponse(resp), nil +} + +// Stream performs a streaming completion. +func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + if err := p.checkRules(req); err != nil { + return err + } + + cl := openai.NewClient(p.requestOptions()...) 
+ oaiReq := p.buildRequest(req) + oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + } + if p.rules.CustomizeRequest != nil { + p.rules.CustomizeRequest(&oaiReq) + } + + stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) + + var fullText strings.Builder + var toolCalls []provider.ToolCall + toolCallArgs := map[int]*strings.Builder{} + var usage *provider.Usage + + for stream.Next() { + chunk := stream.Current() + + // Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true) + if chunk.Usage.TotalTokens > 0 { + usage = &provider.Usage{ + InputTokens: int(chunk.Usage.PromptTokens), + OutputTokens: int(chunk.Usage.CompletionTokens), + TotalTokens: int(chunk.Usage.TotalTokens), + Details: extractUsageDetails(chunk.Usage), + } + } + + for _, choice := range chunk.Choices { + // Text delta + if choice.Delta.Content != "" { + fullText.WriteString(choice.Delta.Content) + events <- provider.StreamEvent{ + Type: provider.StreamEventText, + Text: choice.Delta.Content, + } + } + + // Tool call deltas + for _, tc := range choice.Delta.ToolCalls { + idx := int(tc.Index) + + if tc.ID != "" { + // New tool call starting + for len(toolCalls) <= idx { + toolCalls = append(toolCalls, provider.ToolCall{}) + } + toolCalls[idx].ID = tc.ID + toolCalls[idx].Name = tc.Function.Name + toolCallArgs[idx] = &strings.Builder{} + + events <- provider.StreamEvent{ + Type: provider.StreamEventToolStart, + ToolCall: &provider.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + }, + ToolIndex: idx, + } + } + + if tc.Function.Arguments != "" { + if b, ok := toolCallArgs[idx]; ok { + b.WriteString(tc.Function.Arguments) + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolDelta, + ToolIndex: idx, + ToolCall: &provider.ToolCall{ + Arguments: tc.Function.Arguments, + }, + } + } + } + } + } + + if err := stream.Err(); err != nil { + return fmt.Errorf("openai stream error: %w", err) + } + + // Finalize tool 
calls + for idx := range toolCalls { + if b, ok := toolCallArgs[idx]; ok { + toolCalls[idx].Arguments = b.String() + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolEnd, + ToolIndex: idx, + ToolCall: &toolCalls[idx], + } + } + + events <- provider.StreamEvent{ + Type: provider.StreamEventDone, + Response: &provider.Response{ + Text: fullText.String(), + ToolCalls: toolCalls, + Usage: usage, + }, + } + + return nil +} + +func (p *Provider) requestOptions() []option.RequestOption { + opts := []option.RequestOption{option.WithAPIKey(p.apiKey)} + if p.baseURL != "" { + opts = append(opts, option.WithBaseURL(p.baseURL)) + } + return opts +} + +// checkRules applies all Rules predicates against a request and returns an +// error if any constraint is violated. Runs before any network call. +func (p *Provider) checkRules(req provider.Request) error { + var hasImages, hasAudio bool + for _, msg := range req.Messages { + if len(msg.Images) > 0 { + hasImages = true + } + if len(msg.Audio) > 0 { + hasAudio = true + } + if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage { + return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q", + len(msg.Images), p.rules.MaxImagesPerMessage, req.Model) + } + if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage { + return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q", + len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model) + } + } + + if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) { + return &FeatureUnsupportedError{Feature: "vision", Model: req.Model} + } + if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) { + return &FeatureUnsupportedError{Feature: "audio", Model: req.Model} + } + if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) { + return &FeatureUnsupportedError{Feature: 
"tools", Model: req.Model} + } + return nil +} + +func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { + oaiReq := openai.ChatCompletionNewParams{ + Model: req.Model, + } + + for _, msg := range req.Messages { + oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) + } + + for _, tool := range req.Tools { + oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ + Type: "function", + Function: shared.FunctionDefinitionParam{ + Name: tool.Name, + Description: openai.String(tool.Description), + Parameters: openai.FunctionParameters(tool.Schema), + }, + }) + } + + if req.Temperature != nil { + if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) { + oaiReq.Temperature = openai.Float(*req.Temperature) + } + } + + if req.MaxTokens != nil { + oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) + } + + if req.TopP != nil { + oaiReq.TopP = openai.Float(*req.TopP) + } + + if len(req.Stop) > 0 { + oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} + } + + return oaiReq +} + +func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { + var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam + var textContent param.Opt[string] + + for _, img := range msg.Images { + var url string + if img.Base64 != "" { + url = "data:" + img.ContentType + ";base64," + img.Base64 + } else if img.URL != "" { + url = img.URL + } + if url != "" { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: url, + }, + }, + }, + ) + } + } + + for _, aud := range msg.Audio { + var b64Data string + var format string + + if aud.Base64 != "" { + b64Data = aud.Base64 + format = audioFormat(aud.ContentType) + } else if aud.URL != "" { + resp, err 
:= http.Get(aud.URL) + if err != nil { + continue + } + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + continue + } + b64Data = base64.StdEncoding.EncodeToString(data) + ct := resp.Header.Get("Content-Type") + if ct == "" { + ct = aud.ContentType + } + if ct == "" { + ct = audioFormatFromURL(aud.URL) + } + format = audioFormat(ct) + } + + if b64Data != "" && format != "" { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ + InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ + Data: b64Data, + Format: format, + }, + }, + }, + ) + } + } + + if msg.Content != "" { + if len(arrayOfContentParts) > 0 { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: msg.Content, + }, + }, + ) + } else { + textContent = openai.String(msg.Content) + } + } + + // Determine if this model uses developer messages instead of system + useDeveloper := false + parts := strings.Split(model, "-") + if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { + useDeveloper = true + } + + switch msg.Role { + case "system": + if useDeveloper { + return openai.ChatCompletionMessageParamUnion{ + OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + } + return openai.ChatCompletionMessageParamUnion{ + OfSystem: &openai.ChatCompletionSystemMessageParam{ + Content: openai.ChatCompletionSystemMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + + case "user": + return openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + OfArrayOfContentParts: arrayOfContentParts, + }, + }, + } + + case 
"assistant": + as := &openai.ChatCompletionAssistantMessageParam{} + if msg.Content != "" { + as.Content.OfString = openai.String(msg.Content) + } + for _, tc := range msg.ToolCalls { + as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ + ID: tc.ID, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: tc.Name, + Arguments: tc.Arguments, + }, + }) + } + return openai.ChatCompletionMessageParamUnion{OfAssistant: as} + + case "tool": + return openai.ChatCompletionMessageParamUnion{ + OfTool: &openai.ChatCompletionToolMessageParam{ + ToolCallID: msg.ToolCallID, + Content: openai.ChatCompletionToolMessageParamContentUnion{ + OfString: openai.String(msg.Content), + }, + }, + } + } + + // Fallback to user message + return openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + }, + }, + } +} + +func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response { + var res provider.Response + + if resp == nil || len(resp.Choices) == 0 { + return res + } + + choice := resp.Choices[0] + res.Text = choice.Message.Content + + for _, tc := range choice.Message.ToolCalls { + res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: strings.TrimSpace(tc.Function.Arguments), + }) + } + + if resp.Usage.TotalTokens > 0 { + res.Usage = &provider.Usage{ + InputTokens: int(resp.Usage.PromptTokens), + OutputTokens: int(resp.Usage.CompletionTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + res.Usage.Details = extractUsageDetails(resp.Usage) + } + + return res +} + +// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3"). 
+func audioFormat(contentType string) string { + ct := strings.ToLower(contentType) + switch { + case strings.Contains(ct, "wav"): + return "wav" + case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"): + return "mp3" + default: + return "wav" + } +} + +// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage. +func extractUsageDetails(usage openai.CompletionUsage) map[string]int { + details := map[string]int{} + if usage.CompletionTokensDetails.ReasoningTokens > 0 { + details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens) + } + if usage.CompletionTokensDetails.AudioTokens > 0 { + details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens) + } + if usage.PromptTokensDetails.CachedTokens > 0 { + details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens) + } + if usage.PromptTokensDetails.AudioTokens > 0 { + details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens) + } + if len(details) == 0 { + return nil + } + return details +} + +// audioFormatFromURL guesses the audio format from a URL's file extension. 
+func audioFormatFromURL(u string) string {
+	ext := strings.ToLower(path.Ext(u))
+	switch ext {
+	case ".mp3":
+		return "audio/mp3"
+	case ".wav":
+		return "audio/wav"
+	default:
+		return "audio/wav"
+	}
+}
diff --git a/v2/openaicompat/openaicompat_test.go b/v2/openaicompat/openaicompat_test.go
new file mode 100644
index 0000000..efd412a
--- /dev/null
+++ b/v2/openaicompat/openaicompat_test.go
@@ -0,0 +1,318 @@
+package openaicompat_test
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/openai/openai-go"
+
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
+)
+
+// newTestServer returns a httptest server that captures the raw request body
+// on POST /chat/completions and returns a canned OpenAI response so Complete()
+// succeeds. Dereference the returned *[]byte to assert on what was sent.
+func newTestServer(t *testing.T) (*httptest.Server, *[]byte) {
+	t.Helper()
+	var body []byte
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/chat/completions" {
+			http.NotFound(w, r)
+			return
+		}
+		b, err := io.ReadAll(r.Body)
+		if err != nil {
+			t.Errorf("read body: %v", err)
+		}
+		body = b
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = io.WriteString(w, `{
+  "id": "cmpl-1",
+  "object": "chat.completion",
+  "choices": [{
+    "index": 0,
+    "message": {"role":"assistant","content":"ok"},
+    "finish_reason": "stop"
+  }],
+  "usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
+}`)
+	}))
+	return srv, &body
+}
+
+func textReq(model, content string) provider.Request {
+	return provider.Request{
+		Model:    model,
+		Messages: []provider.Message{{Role: "user", Content: content}},
+	}
+}
+
+func TestComplete_ZeroRulesPassesThrough(t *testing.T) {
+	srv, body := newTestServer(t)
+	defer srv.Close()
+
+	temp := 0.7
+	req := textReq("gpt-4o", "hi")
+	req.Temperature = &temp
+
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
+	resp, err := p.Complete(context.Background(), req)
+	if err != nil {
+		t.Fatalf("Complete: %v", err)
+	}
+	if resp.Text != "ok" {
+		t.Errorf("Text = %q, want %q", resp.Text, "ok")
+	}
+
+	// Temperature should be present since RestrictTemperature is nil.
+	var parsed map[string]any
+	if err := json.Unmarshal(*body, &parsed); err != nil {
+		t.Fatalf("unmarshal request body: %v", err)
+	}
+	if _, ok := parsed["temperature"]; !ok {
+		t.Errorf("expected temperature in request body, got: %s", *body)
+	}
+}
+
+func TestComplete_RestrictTemperatureDropsField(t *testing.T) {
+	srv, body := newTestServer(t)
+	defer srv.Close()
+
+	temp := 0.7
+	req := textReq("o1", "hi")
+	req.Temperature = &temp
+
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
+		RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") },
+	})
+	if _, err := p.Complete(context.Background(), req); err != nil {
+		t.Fatalf("Complete: %v", err)
+	}
+
+	var parsed map[string]any
+	if err := json.Unmarshal(*body, &parsed); err != nil {
+		t.Fatalf("unmarshal: %v", err)
+	}
+	if _, ok := parsed["temperature"]; ok {
+		t.Errorf("temperature should be dropped for o1, got: %s", *body)
+	}
+}
+
+func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) {
+	srv, _ := newTestServer(t)
+	defer srv.Close()
+
+	req := provider.Request{
+		Model: "deepseek-chat",
+		Messages: []provider.Message{{
+			Role:    "user",
+			Content: "describe",
+			Images:  []provider.Image{{URL: "https://example.com/a.png"}},
+		}},
+	}
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
+		SupportsVision: func(string) bool { return false },
+	})
+	_, err := p.Complete(context.Background(), req)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) {
+		t.Fatalf("want FeatureUnsupportedError, got %v", err)
+	}
+	if fue.Feature != "vision" || fue.Model != "deepseek-chat" {
+		t.Errorf("unexpected err: %+v", fue)
+	}
+}
+
+func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) {
+	srv, _ := newTestServer(t)
+	defer srv.Close()
+
+	req := provider.Request{
+		Model:    "deepseek-reasoner",
+		Messages: []provider.Message{{Role: "user", Content: "hi"}},
+		Tools: []provider.ToolDef{
+			{Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}},
+		},
+	}
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
+		SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") },
+	})
+	_, err := p.Complete(context.Background(), req)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) {
+		t.Fatalf("want FeatureUnsupportedError, got %v", err)
+	}
+	if fue.Feature != "tools" {
+		t.Errorf("feature = %q, want tools", fue.Feature)
+	}
+}
+
+func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) {
+	srv, _ := newTestServer(t)
+	defer srv.Close()
+
+	req := provider.Request{
+		Model: "groq-llama",
+		Messages: []provider.Message{{
+			Role:  "user",
+			Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
+		}},
+	}
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
+		SupportsAudio: func(string) bool { return false },
+	})
+	_, err := p.Complete(context.Background(), req)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) {
+		t.Fatalf("want FeatureUnsupportedError, got %v", err)
+	}
+	if fue.Feature != "audio" {
+		t.Errorf("feature = %q, want audio", fue.Feature)
+	}
+}
+
+func TestComplete_MaxImagesPerMessage(t *testing.T) {
+	srv, _ := newTestServer(t)
+	defer srv.Close()
+
+	req := provider.Request{
+		Model: "anything",
+		Messages: []provider.Message{{
+			Role: "user",
+			Images: []provider.Image{
+				{URL: "a"}, {URL: "b"}, {URL: "c"},
+			},
+		}},
+	}
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2})
+	_, err := p.Complete(context.Background(), req)
+	if err == nil || !strings.Contains(err.Error(), "max allowed is 2") {
+		t.Fatalf("want max-images error, got %v", err)
+	}
+
+	// Exactly at limit succeeds.
+	req.Messages[0].Images = req.Messages[0].Images[:2]
+	if _, err := p.Complete(context.Background(), req); err != nil {
+		t.Errorf("at-limit request should succeed, got %v", err)
+	}
+}
+
+func TestComplete_CustomizeRequestInvoked(t *testing.T) {
+	srv, body := newTestServer(t)
+	defer srv.Close()
+
+	called := false
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
+		CustomizeRequest: func(params *openai.ChatCompletionNewParams) {
+			called = true
+			// Confirm we receive a non-empty built request.
+			if params.Model != "gpt-4o" {
+				t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model)
+			}
+			// Mutation here should end up on the wire.
+			params.User = openai.String("test-user")
+		},
+	})
+	if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil {
+		t.Fatalf("Complete: %v", err)
+	}
+	if !called {
+		t.Fatal("CustomizeRequest hook was not invoked")
+	}
+	if !strings.Contains(string(*body), `"user":"test-user"`) {
+		t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body)
+	}
+}
+
+func TestStream_EmitsDoneAndText(t *testing.T) {
+	// SSE stream with one content chunk then [DONE].
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		for _, line := range []string{
+			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`,
+			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`,
+			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
+			`data: [DONE]`,
+		} {
+			_, _ = io.WriteString(w, line+"\n\n")
+			if flusher != nil {
+				flusher.Flush()
+			}
+		}
+	}))
+	defer srv.Close()
+
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
+	events := make(chan provider.StreamEvent, 16)
+	var streamErr error
+	go func() {
+		streamErr = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events)
+		close(events)
+	}()
+
+	var text strings.Builder
+	var sawDone bool
+	var doneUsage *provider.Usage
+	for ev := range events {
+		switch ev.Type {
+		case provider.StreamEventText:
+			text.WriteString(ev.Text)
+		case provider.StreamEventDone:
+			sawDone = true
+			if ev.Response != nil {
+				doneUsage = ev.Response.Usage
+			}
+		}
+	}
+	// Race-free: close(events) happens after the streamErr write, and the loop exits only on close.
+	if streamErr != nil {
+		t.Fatalf("Stream: %v", streamErr)
+	}
+	if text.String() != "hello" {
+		t.Errorf("got text %q, want %q", text.String(), "hello")
+	}
+	if !sawDone {
+		t.Fatal("no Done event emitted")
+	}
+	if doneUsage == nil || doneUsage.TotalTokens != 3 {
+		t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage)
+	}
+}
+
+func TestStream_RulesCheckedBeforeNetwork(t *testing.T) {
+	// Server should never be hit when rules reject up front.
+	hit := false
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		hit = true
+		w.WriteHeader(http.StatusInternalServerError)
+	}))
+	defer srv.Close()
+
+	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
+		SupportsVision: func(string) bool { return false },
+	})
+	req := provider.Request{
+		Model: "no-vision-model",
+		Messages: []provider.Message{{
+			Role:   "user",
+			Images: []provider.Image{{URL: "a"}},
+		}},
+	}
+	events := make(chan provider.StreamEvent, 4)
+	err := p.Stream(context.Background(), req, events)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) {
+		t.Fatalf("want FeatureUnsupportedError, got %v", err)
+	}
+	if hit {
+		t.Error("server was contacted despite Rules violation")
+	}
+}
diff --git a/v2/registry.go b/v2/registry.go
new file mode 100644
index 0000000..58ea1f5
--- /dev/null
+++ b/v2/registry.go
@@ -0,0 +1,164 @@
+package llm
+
+import (
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
+)
+
+// ProviderInfo describes a registered provider for discovery purposes (CLI
+// pickers, wiring layers, admin tools). It is the single source of truth for
+// "what providers exist and how do I instantiate one."
+type ProviderInfo struct {
+	// Name is the short lowercase identifier used in provider/model strings
+	// (e.g., "openai", "deepseek", "moonshot").
+	Name string
+
+	// DisplayName is a human-readable label for UIs.
+	DisplayName string
+
+	// EnvKey is the conventional environment variable that holds the API key
+	// for this provider. Empty string means "no key needed" (e.g., Ollama).
+	EnvKey string
+
+	// DefaultURL is the default base URL used when no override is supplied.
+	DefaultURL string
+
+	// Models is a list of well-known model names, populated for CLI pickers
+	// and similar. It is not exhaustive and not validated against the API.
+	Models []string
+
+	// New returns a ready-to-use Client for this provider, given an API key
+	// (ignored for key-less providers like Ollama) and optional ClientOptions.
+	New func(apiKey string, opts ...ClientOption) *Client
+}
+
+// providerRegistry is the in-process list of known providers. Order is
+// intentional: the three original providers first, then OpenAI-compatible
+// additions in the order they were added.
+var providerRegistry = []ProviderInfo{
+	{
+		Name:        "openai",
+		DisplayName: "OpenAI",
+		EnvKey:      "OPENAI_API_KEY",
+		DefaultURL:  openai.DefaultBaseURL,
+		Models: []string{
+			"gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano",
+			"gpt-4o", "gpt-4o-mini",
+			"gpt-4-turbo", "gpt-3.5-turbo",
+			"o1", "o1-mini", "o1-preview", "o3-mini",
+		},
+		New: OpenAI,
+	},
+	{
+		Name:        "anthropic",
+		DisplayName: "Anthropic",
+		EnvKey:      "ANTHROPIC_API_KEY",
+		DefaultURL:  "https://api.anthropic.com",
+		Models: []string{
+			"claude-opus-4-7",
+			"claude-sonnet-4-6",
+			"claude-haiku-4-5-20251001",
+			"claude-opus-4-20250514",
+			"claude-sonnet-4-20250514",
+			"claude-3-7-sonnet-20250219",
+			"claude-3-5-sonnet-20241022",
+			"claude-3-5-haiku-20241022",
+		},
+		New: Anthropic,
+	},
+	{
+		Name:        "google",
+		DisplayName: "Google",
+		EnvKey:      "GOOGLE_API_KEY",
+		DefaultURL:  "https://generativelanguage.googleapis.com",
+		Models: []string{
+			"gemini-2.0-flash", "gemini-2.0-flash-lite",
+			"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
+		},
+		New: Google,
+	},
+	{
+		Name:        "deepseek",
+		DisplayName: "DeepSeek",
+		EnvKey:      "DEEPSEEK_API_KEY",
+		DefaultURL:  deepseek.DefaultBaseURL,
+		Models:      []string{"deepseek-chat", "deepseek-reasoner"},
+		New:         DeepSeek,
+	},
+	{
+		Name:        "moonshot",
+		DisplayName: "Moonshot (Kimi)",
+		EnvKey:      "MOONSHOT_API_KEY",
+		DefaultURL:  moonshot.DefaultBaseURL,
+		Models: []string{
+			"kimi-k2-0711-preview",
+			"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
+			"moonshot-v1-8k-vision-preview",
+		},
+		New: Moonshot,
+	},
+	{
+		Name:        "xai",
+		DisplayName: "xAI (Grok)",
+		EnvKey:      "XAI_API_KEY",
+		DefaultURL:  xai.DefaultBaseURL,
+		Models: []string{
+			"grok-2", "grok-2-mini", "grok-2-vision", "grok-beta",
+		},
+		New: XAI,
+	},
+	{
+		Name:        "groq",
+		DisplayName: "Groq",
+		EnvKey:      "GROQ_API_KEY",
+		DefaultURL:  groq.DefaultBaseURL,
+		Models: []string{
+			"llama-3.3-70b-versatile",
+			"llama-3.1-8b-instant",
+			"mixtral-8x7b-32768",
+			"gemma2-9b-it",
+			"llama-3.2-90b-vision-preview",
+		},
+		New: Groq,
+	},
+	{
+		Name:        "ollama",
+		DisplayName: "Ollama (local)",
+		EnvKey:      "", // no key needed
+		DefaultURL:  ollama.DefaultBaseURL,
+		Models: []string{
+			"llama3.2", "llama3.1", "qwen2.5", "mistral", "gemma2", "phi4",
+		},
+		New: func(_ string, opts ...ClientOption) *Client { return Ollama(opts...) },
+	},
+}
+
+// Providers returns a copy of the registered provider list so callers cannot
+// mutate library state. The per-entry Models slices are cloned too: a shallow
+// copy would share backing arrays and let callers corrupt the registry.
+func Providers() []ProviderInfo {
+	out := make([]ProviderInfo, len(providerRegistry))
+	for i, p := range providerRegistry {
+		p.Models = append([]string(nil), p.Models...)
+		out[i] = p
+	}
+	return out
+}
+
+// ProviderByName returns the registered ProviderInfo with the given name, or
+// nil if no such provider is registered. Name matching is exact.
+func ProviderByName(name string) *ProviderInfo {
+	for i := range providerRegistry {
+		if providerRegistry[i].Name == name {
+			p := providerRegistry[i]
+			// Clone Models so the caller cannot mutate registry state.
+			p.Models = append([]string(nil), p.Models...)
+			return &p
+		}
+	}
+	return nil
+}
diff --git a/v2/xai/xai.go b/v2/xai/xai.go
new file mode 100644
index 0000000..500fd4c
--- /dev/null
+++ b/v2/xai/xai.go
@@ -0,0 +1,29 @@
+// Package xai implements the go-llm v2 provider interface for xAI (Grok,
+// https://x.ai/api). xAI speaks OpenAI Chat Completions, so this package is a
+// thin wrapper over openaicompat with its own defaults and per-model Rules.
+package xai
+
+import (
+	"strings"
+
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
+)
+
+// DefaultBaseURL is the public xAI API endpoint.
+const DefaultBaseURL = "https://api.x.ai/v1"
+
+// Provider is a type alias over openaicompat.Provider.
+type Provider = openaicompat.Provider
+
+// New creates a new xAI provider. An empty baseURL uses DefaultBaseURL.
+func New(apiKey, baseURL string) *Provider {
+	if baseURL == "" {
+		baseURL = DefaultBaseURL
+	}
+	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
+		// Grok models whose name contains "vision" accept images; others don't.
+		SupportsVision: func(m string) bool {
+			return strings.Contains(m, "vision")
+		},
+	})
+}
diff --git a/v2/xai/xai_test.go b/v2/xai/xai_test.go
new file mode 100644
index 0000000..bed5b6b
--- /dev/null
+++ b/v2/xai/xai_test.go
@@ -0,0 +1,33 @@
+package xai_test
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
+)
+
+func TestNew_Basic(t *testing.T) {
+	if p := xai.New("key", ""); p == nil {
+		t.Fatal("New returned nil")
+	}
+}
+
+func TestRules_Grok2RejectsImages(t *testing.T) {
+	p := xai.New("key", "")
+	req := provider.Request{
+		Model: "grok-2",
+		Messages: []provider.Message{{
+			Role:   "user",
+			Images: []provider.Image{{URL: "a"}},
+		}},
+	}
+	_, err := p.Complete(context.Background(), req)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) || fue.Feature != "vision" {
+		t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
+	}
+}