diff --git a/CLAUDE.md b/CLAUDE.md index 3824b30..4c3fd66 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,88 +1,31 @@ # CLAUDE.md for go-llm -## Build and Test Commands -- Build project: `go build ./...` -- Run all tests: `go test ./...` -- Run specific test: `go test -v -run ./...` -- Tidy dependencies: `go mod tidy` +All Go code now lives under `v2/`. The module path is +`gitea.stevedudenhoeffer.com/steve/go-llm/v2`. There is no module at the +repository root anymore; the v1 code at the root was deleted after all +consumers migrated to v2. -## Code Style Guidelines -- **Indentation**: Use standard Go tabs for indentation. -- **Naming**: - - Use `camelCase` for internal/private variables and functions. - - Use `PascalCase` for exported types, functions, and struct fields. - - Interface names should be concise (e.g., `LLM`, `ChatCompletion`). -- **Error Handling**: - - Always check and handle errors immediately. - - Wrap errors with context using `fmt.Errorf("%w: ...", err)`. - - Use the project's internal `Error` struct in `error.go` when differentiating between error types is needed. -- **Project Structure**: - - `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`). - - Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`. - - Schema definitions for tool calling are in the `schema/` directory. - - `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers. -- **Imports**: Organize imports into groups: standard library, then third-party libraries. -- **Documentation**: Use standard Go doc comments for exported symbols. -- **README.md**: The README.md file should always be kept up to date with any significant changes to the project. +See `v2/CLAUDE.md` for build/test commands and per-package guidance. -## CLI Tool -- Build CLI: `go build ./cmd/llm` -- Run CLI: `./llm` (or `llm.exe` on Windows) -- Run without building: `go run ./cmd/llm` +## CLI -### CLI Features -- Interactive TUI for testing all go-llm features -- Support for OpenAI, Anthropic, and Google providers -- Image input (file path, URL, or base64) -- Tool/function calling with demo tools -- Temperature control and settings +The interactive TUI lives at `v2/cmd/llm`: -### Key Bindings -- `Enter` - Send message -- `Ctrl+I` - Add image -- `Ctrl+T` - Toggle tools panel -- `Ctrl+P` - Change provider -- `Ctrl+M` - Change model -- `Ctrl+S` - Settings -- `Ctrl+N` - New conversation -- `Esc` - Exit/Cancel - -## MCP (Model Context Protocol) Support - -The library supports connecting to MCP servers to use their tools. MCP servers can be connected via: -- **stdio**: Run a command as a subprocess -- **sse**: Connect to an SSE endpoint -- **http**: Connect to a streamable HTTP endpoint - -### Usage Example -```go -ctx := context.Background() - -// Create and connect to an MCP server -server := &llm.MCPServer{ - Name: "my-server", - Command: "my-mcp-server", - Args: []string{"--some-flag"}, -} -if err := server.Connect(ctx); err != nil { - log.Fatal(err) -} -defer server.Close() - -// Add the server to a toolbox -toolbox := llm.NewToolBox().WithMCPServer(server) - -// Use the toolbox in requests - MCP tools are automatically available -req := llm.Request{ - Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}}, - Toolbox: toolbox, -} +``` +cd v2 && go run ./cmd/llm ``` -### MCPServer Options -- `Name`: Friendly name for logging -- `Command`: Command to run (for stdio transport) -- `Args`: Command arguments -- `Env`: Additional environment variables -- `URL`: Endpoint URL (for sse/http transport) -- `Transport`: "stdio" (default), "sse", or "http" +It iterates `llm.Providers()` so every registered provider (OpenAI, Anthropic, +Google, DeepSeek, Moonshot, xAI, Groq, Ollama) appears in the picker +automatically. Status is derived from each provider's env var; Ollama shows as +"(local)" because it needs no key. + +### Key bindings +- `Enter` — Send message +- `Ctrl+I` — Add image +- `Ctrl+T` — Toggle tools panel +- `Ctrl+P` — Change provider +- `Ctrl+M` — Change model +- `Ctrl+S` — Settings +- `Ctrl+N` — New conversation +- `Esc` — Exit/Cancel diff --git a/anthropic.go b/anthropic.go deleted file mode 100644 index b5e0244..0000000 --- a/anthropic.go +++ /dev/null @@ -1,225 +0,0 @@ -package llm - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log" - "log/slog" - "net/http" - - "gitea.stevedudenhoeffer.com/steve/go-llm/internal/imageutil" - - anth "github.com/liushuangls/go-anthropic/v2" -) - -type anthropicImpl struct { - key string - model string -} - -var _ LLM = anthropicImpl{} - -func (a anthropicImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { - a.model = modelVersion - - // TODO: model verification? - return a, nil -} - -func deferClose(c io.Closer) { - err := c.Close() - if err != nil { - slog.Error("error closing", "error", err) - } -} - -func (a anthropicImpl) requestToAnthropicRequest(req Request) anth.MessagesRequest { - res := anth.MessagesRequest{ - Model: anth.Model(a.model), - MaxTokens: 1000, - } - - msgs := []anth.Message{} - - // we gotta convert messages into anthropic messages, however - // anthropic does not have a "system" message type, so we need to - // append it to the res.System field instead - - for _, msg := range req.Messages { - if msg.Role == RoleSystem { - if len(res.System) > 0 { - res.System += "\n" - } - res.System += msg.Text - } else { - role := anth.RoleUser - - if msg.Role == RoleAssistant { - role = anth.RoleAssistant - } - - m := anth.Message{ - Role: role, - Content: []anth.MessageContent{}, - } - - if msg.Text != "" { - m.Content = append(m.Content, anth.MessageContent{ - Type: anth.MessagesContentTypeText, - Text: &msg.Text, - }) - } - - for _, img := range msg.Images { - // anthropic doesn't allow the assistant to send images, so we need to say it's from the user - if m.Role == anth.RoleAssistant { - m.Role = anth.RoleUser - } - - if img.Base64 != "" { - // Anthropic models expect images to be < 5MiB in size - raw, err := base64.StdEncoding.DecodeString(img.Base64) - - if err != nil { - continue - } - - // Check if image size exceeds 5MiB (5242880 bytes) - if len(raw) >= 5242880 { - - compressed, mime, err := imageutil.CompressImage(img.Base64, 5*1024*1024) - - // just replace the image with the compressed one - if err != nil { - continue - } - - img.Base64 = compressed - img.ContentType = mime - } - - m.Content = append(m.Content, anth.NewImageMessageContent( - anth.NewMessageContentSource( - anth.MessagesContentSourceTypeBase64, - img.ContentType, - img.Base64, - ))) - } else if img.Url != "" { - - // download the image - cl, err := http.NewRequest(http.MethodGet, img.Url, nil) - if err != nil { - log.Println("failed to create request", err) - continue - } - - resp, err := http.DefaultClient.Do(cl) - if err != nil { - log.Println("failed to download image", err) - continue - } - - defer deferClose(resp.Body) - - img.ContentType = resp.Header.Get("Content-Type") - - // read the image - b, err := io.ReadAll(resp.Body) - if err != nil { - log.Println("failed to read image", err) - continue - } - - // base64 encode the image - img.Base64 = string(b) - - m.Content = append(m.Content, anth.NewImageMessageContent( - anth.NewMessageContentSource( - anth.MessagesContentSourceTypeBase64, - img.ContentType, - img.Base64, - ))) - } - } - - // if this has the same role as the previous message, we can append it to the previous message - // as anthropic expects alternating assistant and user roles - if len(msgs) > 0 && msgs[len(msgs)-1].Role == role { - m2 := &msgs[len(msgs)-1] - - m2.Content = append(m2.Content, m.Content...) - } else { - msgs = append(msgs, m) - } - } - } - - for _, tool := range req.Toolbox.Functions() { - res.Tools = append(res.Tools, anth.ToolDefinition{ - Name: tool.Name, - Description: tool.Description, - InputSchema: tool.Parameters.AnthropicInputSchema(), - }) - } - - res.Messages = msgs - - if req.Temperature != nil { - var f = float32(*req.Temperature) - res.Temperature = &f - } - - log.Println("llm request to anthropic request", res) - - return res -} - -func (a anthropicImpl) responseToLLMResponse(in anth.MessagesResponse) Response { - choice := ResponseChoice{} - for _, msg := range in.Content { - - switch msg.Type { - case anth.MessagesContentTypeText: - if msg.Text != nil { - choice.Content += *msg.Text - } - - case anth.MessagesContentTypeToolUse: - if msg.MessageContentToolUse != nil { - b, e := json.Marshal(msg.MessageContentToolUse.Input) - if e != nil { - log.Println("failed to marshal input", e) - } else { - choice.Calls = append(choice.Calls, ToolCall{ - ID: msg.MessageContentToolUse.ID, - FunctionCall: FunctionCall{ - Name: msg.MessageContentToolUse.Name, - Arguments: string(b), - }, - }) - } - } - } - } - - log.Println("anthropic response to llm response", choice) - - return Response{ - Choices: []ResponseChoice{choice}, - } -} - -func (a anthropicImpl) ChatComplete(ctx context.Context, req Request) (Response, error) { - cl := anth.NewClient(a.key) - - res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req)) - - if err != nil { - return Response{}, fmt.Errorf("failed to chat complete: %w", err) - } - - return a.responseToLLMResponse(res), nil -} diff --git a/cmd/llm/.env.example b/cmd/llm/.env.example deleted file mode 100644 index 3ddaa80..0000000 --- a/cmd/llm/.env.example +++ /dev/null @@ -1,11 +0,0 @@ -# go-llm CLI Environment Variables -# Copy this file to .env and fill in your API keys - -# OpenAI API Key (https://platform.openai.com/api-keys) -OPENAI_API_KEY= - -# Anthropic API Key (https://console.anthropic.com/settings/keys) -ANTHROPIC_API_KEY= - -# Google AI API Key (https://aistudio.google.com/apikey) -GOOGLE_API_KEY= diff --git a/cmd/llm/commands.go b/cmd/llm/commands.go deleted file mode 100644 index e5e7941..0000000 --- a/cmd/llm/commands.go +++ /dev/null @@ -1,182 +0,0 @@ -package main - -import ( - "context" - "encoding/base64" - "fmt" - "net/http" - "os" - "strings" - - tea "github.com/charmbracelet/bubbletea" - - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// Message types for async operations - -// ChatResponseMsg contains the response from a chat completion -type ChatResponseMsg struct { - Response llm.Response - Err error -} - -// ToolExecutionMsg contains results from tool execution -type ToolExecutionMsg struct { - Results []llm.ToolCallResponse - Err error -} - -// ImageLoadedMsg contains a loaded image -type ImageLoadedMsg struct { - Image llm.Image - Err error -} - -// sendChatRequest sends a chat completion request -func sendChatRequest(chat llm.ChatCompletion, req llm.Request) tea.Cmd { - return func() tea.Msg { - resp, err := chat.ChatComplete(context.Background(), req) - return ChatResponseMsg{Response: resp, Err: err} - } -} - -// executeTools executes tool calls and returns results -func executeTools(toolbox llm.ToolBox, req llm.Request, resp llm.ResponseChoice) tea.Cmd { - return func() tea.Msg { - ctx := llm.NewContext(context.Background(), req, &resp, nil) - var results []llm.ToolCallResponse - - for _, call := range resp.Calls { - result, err := toolbox.Execute(ctx, call) - results = append(results, llm.ToolCallResponse{ - ID: call.ID, - Result: result, - Error: err, - }) - } - - return ToolExecutionMsg{Results: results, Err: nil} - } -} - -// loadImageFromPath loads an image from a file path -func loadImageFromPath(path string) tea.Cmd { - return func() tea.Msg { - // Clean up the path - path = strings.TrimSpace(path) - path = strings.Trim(path, "\"'") - - // Read the file - data, err := os.ReadFile(path) - if err != nil { - return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)} - } - - // Detect content type - contentType := http.DetectContentType(data) - if !strings.HasPrefix(contentType, "image/") { - return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)} - } - - // Base64 encode - encoded := base64.StdEncoding.EncodeToString(data) - - return ImageLoadedMsg{ - Image: llm.Image{ - Base64: encoded, - ContentType: contentType, - }, - } - } -} - -// loadImageFromURL loads an image from a URL -func loadImageFromURL(url string) tea.Cmd { - return func() tea.Msg { - url = strings.TrimSpace(url) - - // For URL images, we can just use the URL directly - return ImageLoadedMsg{ - Image: llm.Image{ - Url: url, - }, - } - } -} - -// loadImageFromBase64 loads an image from base64 data -func loadImageFromBase64(data string) tea.Cmd { - return func() tea.Msg { - data = strings.TrimSpace(data) - - // Check if it's a data URL - if strings.HasPrefix(data, "data:") { - // Parse data URL: data:image/png;base64,.... - parts := strings.SplitN(data, ",", 2) - if len(parts) != 2 { - return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")} - } - - // Extract content type from first part - mediaType := strings.TrimPrefix(parts[0], "data:") - mediaType = strings.TrimSuffix(mediaType, ";base64") - - return ImageLoadedMsg{ - Image: llm.Image{ - Base64: parts[1], - ContentType: mediaType, - }, - } - } - - // Assume it's raw base64, try to detect content type - decoded, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)} - } - - contentType := http.DetectContentType(decoded) - if !strings.HasPrefix(contentType, "image/") { - return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)} - } - - return ImageLoadedMsg{ - Image: llm.Image{ - Base64: data, - ContentType: contentType, - }, - } - } -} - -// buildRequest builds a chat request from the current state -func buildRequest(m *Model, userText string) llm.Request { - // Create the user message with any pending images - userMsg := llm.Message{ - Role: llm.RoleUser, - Text: userText, - Images: m.pendingImages, - } - - req := llm.Request{ - Conversation: m.conversation, - Messages: []llm.Message{ - {Role: llm.RoleSystem, Text: m.systemPrompt}, - userMsg, - }, - Temperature: m.temperature, - } - - // Add toolbox if enabled - if m.toolsEnabled && len(m.toolbox.Functions()) > 0 { - req.Toolbox = m.toolbox.WithRequireTool(false) - } - - return req -} - -// buildFollowUpRequest builds a follow-up request after tool execution -func buildFollowUpRequest(m *Model, previousReq llm.Request, resp llm.ResponseChoice, toolResults []llm.ToolCallResponse) llm.Request { - return previousReq.NextRequest(resp, toolResults) -} diff --git a/cmd/llm/tools.go b/cmd/llm/tools.go deleted file mode 100644 index 58f8164..0000000 --- a/cmd/llm/tools.go +++ /dev/null @@ -1,105 +0,0 @@ -package main - -import ( - "fmt" - "math" - "strconv" - "strings" - "time" - - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// TimeParams is the parameter struct for the GetTime function -type TimeParams struct{} - -// GetTime returns the current time -func GetTime(_ *llm.Context, _ TimeParams) (any, error) { - return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil -} - -// CalcParams is the parameter struct for the Calculate function -type CalcParams struct { - A float64 `json:"a" description:"First number"` - B float64 `json:"b" description:"Second number"` - Op string `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"` -} - -// Calculate performs basic math operations -func Calculate(_ *llm.Context, params CalcParams) (any, error) { - switch strings.ToLower(params.Op) { - case "add", "+": - return params.A + params.B, nil - case "subtract", "sub", "-": - return params.A - params.B, nil - case "multiply", "mul", "*": - return params.A * params.B, nil - case "divide", "div", "/": - if params.B == 0 { - return nil, fmt.Errorf("division by zero") - } - return params.A / params.B, nil - case "power", "pow", "^": - return math.Pow(params.A, params.B), nil - case "sqrt": - if params.A < 0 { - return nil, fmt.Errorf("cannot take square root of negative number") - } - return math.Sqrt(params.A), nil - case "mod", "%": - return math.Mod(params.A, params.B), nil - default: - return nil, fmt.Errorf("unknown operation: %s", params.Op) - } -} - -// WeatherParams is the parameter struct for the GetWeather function -type WeatherParams struct { - Location string `json:"location" description:"City name or location"` -} - -// GetWeather returns mock weather data (for demo purposes) -func GetWeather(_ *llm.Context, params WeatherParams) (any, error) { - // This is a demo function - returns mock data - weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"} - temps := []int{65, 72, 58, 80, 45} - - // Use location string to deterministically pick weather - idx := len(params.Location) % len(weathers) - - return map[string]any{ - "location": params.Location, - "temperature": strconv.Itoa(temps[idx]) + "F", - "condition": weathers[idx], - "humidity": "45%", - "note": "This is mock data for demonstration purposes", - }, nil -} - -// RandomNumberParams is the parameter struct for the RandomNumber function -type RandomNumberParams struct { - Min int `json:"min" description:"Minimum value (inclusive)"` - Max int `json:"max" description:"Maximum value (inclusive)"` -} - -// RandomNumber generates a pseudo-random number (using current time nanoseconds) -func RandomNumber(_ *llm.Context, params RandomNumberParams) (any, error) { - if params.Min > params.Max { - return nil, fmt.Errorf("min cannot be greater than max") - } - // Simple pseudo-random using time - n := time.Now().UnixNano() - rangeSize := params.Max - params.Min + 1 - result := params.Min + int(n%int64(rangeSize)) - return result, nil -} - -// createDemoToolbox creates a toolbox with demo tools for testing -func createDemoToolbox() llm.ToolBox { - return llm.NewToolBox( - llm.NewFunction("get_time", "Get the current date and time", GetTime), - llm.NewFunction("calculate", "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", Calculate), - llm.NewFunction("get_weather", "Get weather information for a location (demo data)", GetWeather), - llm.NewFunction("random_number", "Generate a random number between min and max", RandomNumber), - ) -} diff --git a/context.go b/context.go deleted file mode 100644 index 4990bb0..0000000 --- a/context.go +++ /dev/null @@ -1,120 +0,0 @@ -package llm - -import ( - "context" - "time" -) - -type Context struct { - context.Context - request Request - response *ResponseChoice - toolcall *ToolCall - syntheticFields map[string]string -} - -func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request { - var res Request - - res.Toolbox = c.request.Toolbox - res.Temperature = c.request.Temperature - - res.Conversation = make([]Input, len(c.request.Conversation)) - copy(res.Conversation, c.request.Conversation) - - // now for every input message, convert those to an Input to add to the conversation - for _, msg := range c.request.Messages { - res.Conversation = append(res.Conversation, msg) - } - - // if there are tool calls, then we need to add those to the conversation - if c.response != nil { - res.Conversation = append(res.Conversation, *c.response) - } - - // if there are tool results, then we need to add those to the conversation - for _, result := range toolResults { - res.Conversation = append(res.Conversation, result) - } - - return res -} - -func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context { - return &Context{Context: ctx, request: request, response: response, toolcall: toolcall} -} - -func (c *Context) Request() Request { - return c.request -} - -func (c *Context) Response() *ResponseChoice { - return c.response -} - -func (c *Context) ToolCall() *ToolCall { - return c.toolcall -} - -func (c *Context) SyntheticFields() map[string]string { - if c.syntheticFields == nil { - c.syntheticFields = map[string]string{} - } - - return c.syntheticFields -} - -func (c *Context) WithContext(ctx context.Context) *Context { - return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithRequest(request Request) *Context { - return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithResponse(response *ResponseChoice) *Context { - return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithToolCall(toolcall *ToolCall) *Context { - return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields} -} - -func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context { - return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields} -} - -func (c *Context) Deadline() (deadline time.Time, ok bool) { - return c.Context.Deadline() -} - -func (c *Context) Done() <-chan struct{} { - return c.Context.Done() -} - -func (c *Context) Err() error { - return c.Context.Err() -} - -func (c *Context) Value(key any) any { - switch key { - case "request": - return c.request - - case "response": - return c.response - - case "toolcall": - return c.toolcall - - case "syntheticFields": - return c.syntheticFields - - } - return c.Context.Value(key) -} - -func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) { - ctx, cancel := context.WithTimeout(c.Context, timeout) - return c.WithContext(ctx), cancel -} diff --git a/error.go b/error.go deleted file mode 100644 index 8f432bb..0000000 --- a/error.go +++ /dev/null @@ -1,21 +0,0 @@ -package llm - -import "fmt" - -// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error. -type Error struct { - error - - Source error - Parameter error -} - -func newError(parent error, err error) Error { - e := fmt.Errorf("%w: %w", parent, err) - return Error{ - error: e, - - Source: parent, - Parameter: err, - } -} diff --git a/function.go b/function.go deleted file mode 100644 index 4ef5df0..0000000 --- a/function.go +++ /dev/null @@ -1,136 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "reflect" - "time" - - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" -) - -type Function struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Strict bool `json:"strict,omitempty"` - Parameters schema.Type `json:"parameters"` - - Forced bool `json:"forced,omitempty"` - - // Timeout is the maximum time to wait for the function to complete - Timeout time.Duration `json:"-"` - - // fn is the function to call, only set if this is constructed with NewFunction - fn reflect.Value - - paramType reflect.Type -} - -func (f Function) WithSyntheticField(name string, description string) Function { - if obj, o := f.Parameters.(schema.Object); o { - f.Parameters = obj.WithSyntheticField(name, description) - } - - return f -} - -func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function { - if obj, o := f.Parameters.(schema.Object); o { - for k, v := range fieldsAndDescriptions { - obj = obj.WithSyntheticField(k, v) - } - f.Parameters = obj - } - - return f -} - -func (f Function) WithDescription(description string) Function { - f.Description = description - return f -} - -func (f Function) Execute(ctx *Context, input string) (any, error) { - if !f.fn.IsValid() { - return "", fmt.Errorf("function %s is not implemented", f.Name) - } - - slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType) - // first, we need to parse the input into the struct - p := reflect.New(f.paramType) - fmt.Println("Function.Execute", f.Name, "input:", input) - - var vals map[string]any - err := json.Unmarshal([]byte(input), &vals) - - var syntheticFields map[string]string - - // first eat up any synthetic fields - if obj, o := f.Parameters.(schema.Object); o { - for k := range obj.SyntheticFields() { - key := schema.SyntheticFieldPrefix + k - if val, ok := vals[key]; ok { - if syntheticFields == nil { - syntheticFields = map[string]string{} - } - - syntheticFields[k] = fmt.Sprint(val) - delete(vals, key) - } - } - } - - // now for any remaining fields, re-marshal them into json and then unmarshal into the struct - b, err := json.Marshal(vals) - if err != nil { - return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input) - } - - // now we can unmarshal the input into the struct - err = json.Unmarshal(b, p.Interface()) - if err != nil { - return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input) - } - - // now we can call the function - exec := func(ctx *Context) (any, error) { - out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()}) - - if len(out) != 2 { - return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out)) - } - - if out[1].IsNil() { - return out[0].Interface(), nil - } - - return "", out[1].Interface().(error) - } - - var cancel context.CancelFunc - if f.Timeout > 0 { - ctx, cancel = ctx.WithTimeout(f.Timeout) - defer cancel() - } - - return exec(ctx) -} - -type FunctionCall struct { - Name string `json:"name,omitempty"` - Arguments string `json:"arguments,omitempty"` -} - -func (fc *FunctionCall) toRaw() map[string]any { - res := map[string]interface{}{ - "name": fc.Name, - } - - if fc.Arguments != "" { - res["arguments"] = fc.Arguments - } - - return res -} diff --git a/functions.go b/functions.go deleted file mode 100644 index dfdcebd..0000000 --- a/functions.go +++ /dev/null @@ -1,35 +0,0 @@ -package llm - -import ( - "reflect" - - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" -) - -// Parse takes a function pointer and returns a function object. -// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains -// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to -// those types. -// The struct parameters can have the following tags: -// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for - -func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function { - var o T - - res := Function{ - Name: name, - Description: description, - Parameters: schema.GetType(o), - fn: reflect.ValueOf(fn), - paramType: reflect.TypeOf(o), - } - - if res.fn.Kind() != reflect.Func { - panic("fn must be a function") - } - if res.paramType.Kind() != reflect.Struct { - panic("function parameter must be a struct") - } - - return res -} diff --git a/go.mod b/go.mod deleted file mode 100644 index 8e10002..0000000 --- a/go.mod +++ /dev/null @@ -1,67 +0,0 @@ -module gitea.stevedudenhoeffer.com/steve/go-llm - -go 1.24.0 - -toolchain go1.24.2 - -require ( - github.com/charmbracelet/bubbles v0.21.0 - github.com/charmbracelet/bubbletea v1.3.10 - github.com/charmbracelet/lipgloss v1.1.0 - github.com/joho/godotenv v1.5.1 - github.com/liushuangls/go-anthropic/v2 v2.17.0 - github.com/modelcontextprotocol/go-sdk v1.2.0 - github.com/openai/openai-go v1.12.0 - golang.org/x/image v0.35.0 - google.golang.org/genai v1.43.0 -) - -require ( - cloud.google.com/go v0.123.0 // indirect - cloud.google.com/go/auth v0.18.1 // indirect - cloud.google.com/go/compute/metadata v0.9.0 // indirect - github.com/atotto/clipboard v0.1.4 // indirect - github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/x/ansi v0.10.1 // indirect - github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect - github.com/charmbracelet/x/term v0.2.1 // indirect - github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.3 // indirect - github.com/go-logr/stdr v1.2.2 // indirect - github.com/google/go-cmp v0.7.0 // indirect - github.com/google/jsonschema-go v0.3.0 // indirect - github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect - github.com/googleapis/gax-go/v2 v2.16.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect - github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect - github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.16.0 // indirect - github.com/rivo/uniseg v0.4.7 // indirect - github.com/tidwall/gjson v1.18.0 // indirect - github.com/tidwall/match v1.2.0 // indirect - github.com/tidwall/pretty v1.2.1 // indirect - github.com/tidwall/sjson v1.2.5 // indirect - github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect - go.opentelemetry.io/otel v1.39.0 // indirect - go.opentelemetry.io/otel/metric v1.39.0 // indirect - go.opentelemetry.io/otel/trace v1.39.0 // indirect - golang.org/x/crypto v0.47.0 // indirect - golang.org/x/net v0.49.0 // indirect - golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect - google.golang.org/grpc v1.78.0 // indirect - google.golang.org/protobuf v1.36.11 // indirect -) diff --git a/go.sum b/go.sum deleted file mode 100644 index a7a8e88..0000000 --- a/go.sum +++ /dev/null @@ -1,145 +0,0 @@ -cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= -cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= -cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= -cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= -cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= -cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= -github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= -github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= -github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= -github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= -github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= -github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= -github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= -github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= -github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= -github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= -github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= -github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= -github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= -github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= -github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= -github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y= -github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c= -github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= -github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= -github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= -github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= -github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= -github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= -github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= -github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= -github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= -github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= -github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= -github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= -github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= -go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= -go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= -go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= -go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= -go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= -go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= -go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= -go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= -go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= -go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= -go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= -golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= -golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= -golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM= -google.golang.org/genai v1.43.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d h1:xXzuihhT3gL/ntduUZwHECzAn57E8dA6l8SOtYWdD8Q= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= -google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= -google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= -google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= -google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/google.go b/google.go deleted file mode 100644 index 6316f63..0000000 --- a/google.go +++ /dev/null @@ -1,165 +0,0 @@ -package llm - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - - "google.golang.org/genai" -) - -type googleImpl struct { - key string - model string -} - -var _ LLM = googleImpl{} - -func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { - g.model = modelVersion - - return g, nil -} - -func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) { - var contents []*genai.Content - var cfg genai.GenerateContentConfig - - for _, tool := range in.Toolbox.Functions() { - cfg.Tools = append(cfg.Tools, &genai.Tool{ - FunctionDeclarations: []*genai.FunctionDeclaration{ - { - Name: tool.Name, - Description: tool.Description, - Parameters: tool.Parameters.GoogleParameters(), - }, - }, - }) - } - - if in.Toolbox.RequiresTool() { - cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{ - Mode: genai.FunctionCallingConfigModeAny, - }} - } - - for _, c := range in.Messages { - var role genai.Role - switch c.Role { - case RoleAssistant, RoleSystem: - role = genai.RoleModel - case RoleUser: - role = genai.RoleUser - } - - var parts []*genai.Part - if c.Text != "" { - parts = append(parts, genai.NewPartFromText(c.Text)) - } - - for _, img := range c.Images { - if img.Url != "" { - // gemini does not support URLs, so we need to download the image and convert it to a blob - resp, err := http.Get(img.Url) - if err != nil { - panic(fmt.Sprintf("error downloading image: %v", err)) - } - defer resp.Body.Close() - - if resp.ContentLength > 20*1024*1024 { - panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength)) - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - panic(fmt.Sprintf("error reading image data: %v", err)) - } - - mimeType := http.DetectContentType(data) - switch mimeType { - case "image/jpeg", "image/png", "image/gif": - // MIME type is valid - default: - panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType)) - } - - parts = append(parts, genai.NewPartFromBytes(data, mimeType)) - } else { - b, e := base64.StdEncoding.DecodeString(img.Base64) - if e != nil { - panic(fmt.Sprintf("error decoding base64: %v", e)) - } - - parts = append(parts, genai.NewPartFromBytes(b, img.ContentType)) - } - } - - contents = append(contents, genai.NewContentFromParts(parts, role)) - } - - return contents, &cfg -} - -func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) { - res := Response{} - - for _, c := range in.Candidates { - var choice ResponseChoice - var set = false - if c.Content != nil { - for _, p := range c.Content.Parts { - if p.Text != "" { - set = true - choice.Content = p.Text - } else if p.FunctionCall != nil { - v := p.FunctionCall - b, e := json.Marshal(v.Args) - if e != nil { - return Response{}, fmt.Errorf("error marshalling args: %w", e) - } - - call := ToolCall{ - ID: v.Name, - FunctionCall: FunctionCall{ - Name: v.Name, - Arguments: string(b), - }, - } - - choice.Calls = append(choice.Calls, call) - set = true - } - } - } - - if set { - choice.Role = RoleAssistant - res.Choices = append(res.Choices, choice) - } - } - - return res, nil -} - -func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) { - cl, err := genai.NewClient(ctx, &genai.ClientConfig{ - APIKey: g.key, - Backend: genai.BackendGeminiAPI, - }) - - if err != nil { - return Response{}, fmt.Errorf("error creating genai client: %w", err) - } - - contents, cfg := g.requestToContents(req) - - resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg) - if err != nil { - return Response{}, fmt.Errorf("error generating content: %w", err) - } - - return g.responseToLLMResponse(resp) -} diff --git a/internal/imageutil/compress.go b/internal/imageutil/compress.go deleted file mode 100644 index cba4d25..0000000 --- a/internal/imageutil/compress.go +++ /dev/null @@ -1,114 +0,0 @@ -package imageutil - -import ( - "bytes" - "encoding/base64" - "fmt" - "image" - "image/gif" - "image/jpeg" - "net/http" - - "golang.org/x/image/draw" -) - -// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns -// a base-64-encoded version that is at most maxLength in size, or an error. -func CompressImage(b64 string, maxLength int) (string, string, error) { - raw, err := base64.StdEncoding.DecodeString(b64) - if err != nil { - return "", "", fmt.Errorf("base64 decode: %w", err) - } - - mime := http.DetectContentType(raw) - if len(raw) <= maxLength { - return b64, mime, nil // small enough already - } - - switch mime { - case "image/gif": - return compressGIF(raw, maxLength) - - default: // jpeg, png, webp, etc. -> treat as raster - return compressRaster(raw, maxLength) - } -} - -// ---------- Raster path (jpeg / png / single-frame gif) ---------- - -func compressRaster(src []byte, maxLength int) (string, string, error) { - img, _, err := image.Decode(bytes.NewReader(src)) - if err != nil { - return "", "", fmt.Errorf("decode raster: %w", err) - } - - quality := 95 - for { - var buf bytes.Buffer - if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil { - return "", "", fmt.Errorf("jpeg encode: %w", err) - } - if buf.Len() <= maxLength { - return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/jpeg", nil - } - - if quality > 20 { - quality -= 5 - continue - } - - // down-scale 80% - b := img.Bounds() - if b.Dx() < 100 || b.Dy() < 100 { - return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0) - } - dst := image.NewRGBA(image.Rect(0, 0, int(float64(b.Dx())*0.8), int(float64(b.Dy())*0.8))) - draw.ApproxBiLinear.Scale(dst, dst.Bounds(), img, b, draw.Over, nil) - img = dst - quality = 95 // restart ladder - } -} - -// ---------- Animated GIF path ---------- - -func compressGIF(src []byte, maxLength int) (string, string, error) { - g, err := gif.DecodeAll(bytes.NewReader(src)) - if err != nil { - return "", "", fmt.Errorf("gif decode: %w", err) - } - - for { - var buf bytes.Buffer - if err := gif.EncodeAll(&buf, g); err != nil { - return "", "", fmt.Errorf("gif encode: %w", err) - } - if buf.Len() <= maxLength { - return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil - } - - // down-scale every frame by 80% - w, h := g.Config.Width, g.Config.Height - if w < 100 || h < 100 { - return "", "", fmt.Errorf("cannot compress animated GIF below 5 MiB without excessive quality loss") - } - - nw, nh := int(float64(w)*0.8), int(float64(h)*0.8) - for i, frm := range g.Image { - // convert paletted frame -> RGBA for scaling - rgba := image.NewRGBA(frm.Bounds()) - draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src) - - // scaled destination - dst := image.NewRGBA(image.Rect(0, 0, nw, nh)) - draw.ApproxBiLinear.Scale(dst, dst.Bounds(), rgba, rgba.Bounds(), draw.Over, nil) - - // quantize back to paletted using default encoder quantizer - paletted := image.NewPaletted(dst.Bounds(), nil) - draw.FloydSteinberg.Draw(paletted, paletted.Bounds(), dst, dst.Bounds().Min) - - g.Image[i] = paletted - } - g.Config.Width, g.Config.Height = nw, nh - // loop back and test size again ... - } -} diff --git a/llm.go b/llm.go deleted file mode 100644 index 3038430..0000000 --- a/llm.go +++ /dev/null @@ -1,30 +0,0 @@ -package llm - -import ( - "context" -) - -// ChatCompletion is the interface for chat completion. -type ChatCompletion interface { - ChatComplete(ctx context.Context, req Request) (Response, error) -} - -// LLM is the interface for language model providers. -type LLM interface { - ModelVersion(modelVersion string) (ChatCompletion, error) -} - -// OpenAI creates a new OpenAI LLM provider with the given API key. -func OpenAI(key string) LLM { - return openaiImpl{key: key} -} - -// Anthropic creates a new Anthropic LLM provider with the given API key. -func Anthropic(key string) LLM { - return anthropicImpl{key: key} -} - -// Google creates a new Google LLM provider with the given API key. -func Google(key string) LLM { - return googleImpl{key: key} -} diff --git a/mcp.go b/mcp.go deleted file mode 100644 index 342a6fb..0000000 --- a/mcp.go +++ /dev/null @@ -1,238 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "fmt" - "os" - "os/exec" - "sync" - - "github.com/modelcontextprotocol/go-sdk/mcp" - - "gitea.stevedudenhoeffer.com/steve/go-llm/schema" -) - -// MCPServer represents a connection to an MCP server. -// It manages the lifecycle of the connection and provides access to the server's tools. -type MCPServer struct { - // Name is a friendly name for this server (used for logging/identification) - Name string - - // Command is the command to run the MCP server (for stdio transport) - Command string - - // Args are arguments to pass to the command - Args []string - - // Env are environment variables to set for the command (in addition to current environment) - Env []string - - // URL is the URL for SSE or HTTP transport (alternative to Command) - URL string - - // Transport specifies the transport type: "stdio" (default), "sse", or "http" - Transport string - - client *mcp.Client - session *mcp.ClientSession - tools map[string]*mcp.Tool // tool name -> tool definition - mu sync.RWMutex -} - -// Connect establishes a connection to the MCP server. -func (m *MCPServer) Connect(ctx context.Context) error { - m.mu.Lock() - defer m.mu.Unlock() - - if m.session != nil { - return nil // Already connected - } - - m.client = mcp.NewClient(&mcp.Implementation{ - Name: "go-llm", - Version: "1.0.0", - }, nil) - - var transport mcp.Transport - - switch m.Transport { - case "sse": - transport = &mcp.SSEClientTransport{ - Endpoint: m.URL, - } - case "http": - transport = &mcp.StreamableClientTransport{ - Endpoint: m.URL, - } - default: // "stdio" or empty - cmd := exec.Command(m.Command, m.Args...) - cmd.Env = append(os.Environ(), m.Env...) - transport = &mcp.CommandTransport{ - Command: cmd, - } - } - - session, err := m.client.Connect(ctx, transport, nil) - if err != nil { - return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err) - } - - m.session = session - - // Load tools - m.tools = make(map[string]*mcp.Tool) - for tool, err := range session.Tools(ctx, nil) { - if err != nil { - m.session.Close() - m.session = nil - return fmt.Errorf("failed to list tools from %s: %w", m.Name, err) - } - m.tools[tool.Name] = tool - } - - return nil -} - -// Close closes the connection to the MCP server. -func (m *MCPServer) Close() error { - m.mu.Lock() - defer m.mu.Unlock() - - if m.session == nil { - return nil - } - - err := m.session.Close() - m.session = nil - m.tools = nil - return err -} - -// IsConnected returns true if the server is connected. -func (m *MCPServer) IsConnected() bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.session != nil -} - -// Tools returns the list of tool names available from this server. -func (m *MCPServer) Tools() []string { - m.mu.RLock() - defer m.mu.RUnlock() - - var names []string - for name := range m.tools { - names = append(names, name) - } - return names -} - -// HasTool returns true if this server provides the named tool. -func (m *MCPServer) HasTool(name string) bool { - m.mu.RLock() - defer m.mu.RUnlock() - _, ok := m.tools[name] - return ok -} - -// CallTool calls a tool on the MCP server. -func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) { - m.mu.RLock() - session := m.session - m.mu.RUnlock() - - if session == nil { - return nil, fmt.Errorf("not connected to MCP server %s", m.Name) - } - - result, err := session.CallTool(ctx, &mcp.CallToolParams{ - Name: name, - Arguments: arguments, - }) - if err != nil { - return nil, err - } - - // Process the result content - if len(result.Content) == 0 { - return nil, nil - } - - // If there's a single text content, return it as a string - if len(result.Content) == 1 { - if textContent, ok := result.Content[0].(*mcp.TextContent); ok { - return textContent.Text, nil - } - } - - // For multiple contents or non-text, serialize to string - return contentToString(result.Content), nil -} - -// toFunction converts an MCP tool to a go-llm Function (for schema purposes only). -func (m *MCPServer) toFunction(tool *mcp.Tool) Function { - var inputSchema map[string]any - if tool.InputSchema != nil { - data, err := json.Marshal(tool.InputSchema) - if err == nil { - _ = json.Unmarshal(data, &inputSchema) - } - } - - if inputSchema == nil { - inputSchema = map[string]any{ - "type": "object", - "properties": map[string]any{}, - } - } - - return Function{ - Name: tool.Name, - Description: tool.Description, - Parameters: schema.NewRaw(inputSchema), - } -} - -// contentToString converts MCP content to a string representation. -func contentToString(content []mcp.Content) string { - var parts []string - for _, c := range content { - switch tc := c.(type) { - case *mcp.TextContent: - parts = append(parts, tc.Text) - default: - if data, err := json.Marshal(c); err == nil { - parts = append(parts, string(data)) - } - } - } - if len(parts) == 1 { - return parts[0] - } - data, _ := json.Marshal(parts) - return string(data) -} - -// WithMCPServer adds an MCP server to the toolbox. -// The server must already be connected. Tools from the server will be available -// for use, and tool calls will be routed to the appropriate server. -func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox { - if t.mcpServers == nil { - t.mcpServers = make(map[string]*MCPServer) - } - - server.mu.RLock() - defer server.mu.RUnlock() - - for name, tool := range server.tools { - // Add the function definition (for schema) - fn := server.toFunction(tool) - t.functions[name] = fn - - // Track which server owns this tool - t.mcpServers[name] = server - } - - return t -} diff --git a/message.go b/message.go deleted file mode 100644 index b1296ba..0000000 --- a/message.go +++ /dev/null @@ -1,115 +0,0 @@ -package llm - -// Role represents the role of a message in a conversation. -type Role string - -const ( - RoleSystem Role = "system" - RoleUser Role = "user" - RoleAssistant Role = "assistant" -) - -// Image represents an image that can be included in a message. -type Image struct { - Base64 string - ContentType string - Url string -} - -func (i Image) toRaw() map[string]any { - res := map[string]any{ - "base64": i.Base64, - "contenttype": i.ContentType, - "url": i.Url, - } - - return res -} - -func (i *Image) fromRaw(raw map[string]any) Image { - var res Image - - res.Base64 = raw["base64"].(string) - res.ContentType = raw["contenttype"].(string) - res.Url = raw["url"].(string) - - return res -} - -// Message represents a message in a conversation. -type Message struct { - Role Role - Name string - Text string - Images []Image -} - -func (m Message) toRaw() map[string]any { - res := map[string]any{ - "role": m.Role, - "name": m.Name, - "text": m.Text, - } - - images := make([]map[string]any, 0, len(m.Images)) - for _, img := range m.Images { - images = append(images, img.toRaw()) - } - - res["images"] = images - - return res -} - -func (m *Message) fromRaw(raw map[string]any) Message { - var res Message - - res.Role = Role(raw["role"].(string)) - res.Name = raw["name"].(string) - res.Text = raw["text"].(string) - - images := raw["images"].([]map[string]any) - for _, img := range images { - var i Image - - res.Images = append(res.Images, i.fromRaw(img)) - } - - return res -} - -// ToolCall represents a tool call made by an assistant. -type ToolCall struct { - ID string - FunctionCall FunctionCall -} - -func (t ToolCall) toRaw() map[string]any { - res := map[string]any{ - "id": t.ID, - } - - res["function"] = t.FunctionCall.toRaw() - - return res -} - -// ToolCallResponse represents the response to a tool call. -type ToolCallResponse struct { - ID string - Result any - Error error -} - -func (t ToolCallResponse) toRaw() map[string]any { - res := map[string]any{ - "id": t.ID, - "result": t.Result, - } - - if t.Error != nil { - res["error"] = t.Error.Error() - } - - return res -} diff --git a/openai.go b/openai.go deleted file mode 100644 index 5d71402..0000000 --- a/openai.go +++ /dev/null @@ -1,322 +0,0 @@ -package llm - -import ( - "context" - "fmt" - "strings" - - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/shared" -) - -type openaiImpl struct { - key string - model string - baseUrl string -} - -var _ LLM = openaiImpl{} - -func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams { - res := openai.ChatCompletionNewParams{ - Model: o.model, - } - - for _, i := range request.Conversation { - res.Messages = append(res.Messages, inputToChatCompletionMessages(i, o.model)...) - } - - for _, msg := range request.Messages { - res.Messages = append(res.Messages, messageToChatCompletionMessages(msg, o.model)...) - } - - for _, tool := range request.Toolbox.Functions() { - res.Tools = append(res.Tools, openai.ChatCompletionToolParam{ - Type: "function", - Function: shared.FunctionDefinitionParam{ - Name: tool.Name, - Description: openai.String(tool.Description), - Strict: openai.Bool(tool.Strict), - Parameters: tool.Parameters.OpenAIParameters(), - }, - }) - } - - if request.Toolbox.RequiresTool() { - res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ - OfAuto: openai.String("required"), - } - } - - if request.Temperature != nil { - // these are known models that do not support custom temperatures - // all the o* models - // gpt-5* models - if !strings.HasPrefix(o.model, "o") && !strings.HasPrefix(o.model, "gpt-5") { - res.Temperature = openai.Float(*request.Temperature) - } - } - - return res -} - -func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response { - var res Response - - if response == nil { - return res - } - - if len(response.Choices) == 0 { - return res - } - - for _, choice := range response.Choices { - var toolCalls []ToolCall - for _, call := range choice.Message.ToolCalls { - toolCall := ToolCall{ - ID: call.ID, - FunctionCall: FunctionCall{ - Name: call.Function.Name, - Arguments: strings.TrimSpace(call.Function.Arguments), - }, - } - - toolCalls = append(toolCalls, toolCall) - - } - res.Choices = append(res.Choices, ResponseChoice{ - Content: choice.Message.Content, - Role: Role(choice.Message.Role), - Refusal: choice.Message.Refusal, - Calls: toolCalls, - }) - } - - return res -} - -func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) { - var opts = []option.RequestOption{ - option.WithAPIKey(o.key), - } - - if o.baseUrl != "" { - opts = append(opts, option.WithBaseURL(o.baseUrl)) - } - - cl := openai.NewClient(opts...) - - req := o.newRequestToOpenAIRequest(request) - - resp, err := cl.Chat.Completions.New(ctx, req) - - if err != nil { - return Response{}, fmt.Errorf("unhandled openai error: %w", err) - } - - return o.responseToLLMResponse(resp), nil -} - -func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) { - return openaiImpl{ - key: o.key, - model: modelVersion, - baseUrl: o.baseUrl, - }, nil -} - -// inputToChatCompletionMessages converts an Input to OpenAI chat completion messages. -func inputToChatCompletionMessages(input Input, model string) []openai.ChatCompletionMessageParamUnion { - switch v := input.(type) { - case Message: - return messageToChatCompletionMessages(v, model) - case ToolCall: - return toolCallToChatCompletionMessages(v) - case ToolCallResponse: - return toolCallResponseToChatCompletionMessages(v) - case ResponseChoice: - return responseChoiceToChatCompletionMessages(v) - default: - return nil - } -} - -func messageToChatCompletionMessages(m Message, model string) []openai.ChatCompletionMessageParamUnion { - var res openai.ChatCompletionMessageParamUnion - - var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam - var textContent param.Opt[string] - - for _, img := range m.Images { - if img.Base64 != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ - URL: "data:" + img.ContentType + ";base64," + img.Base64, - }, - }, - }, - ) - } else if img.Url != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ - URL: img.Url, - }, - }, - }, - ) - } - } - - if m.Text != "" { - if len(arrayOfContentParts) > 0 { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{ - Text: "\n", - }, - }, - ) - } else { - textContent = openai.String(m.Text) - } - } - - a := strings.Split(model, "-") - - useSystemInsteadOfDeveloper := true - if len(a) > 1 && a[0][0] == 'o' { - useSystemInsteadOfDeveloper = false - } - - switch m.Role { - case RoleSystem: - if useSystemInsteadOfDeveloper { - res = openai.ChatCompletionMessageParamUnion{ - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } else { - res = openai.ChatCompletionMessageParamUnion{ - OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } - - case RoleUser: - var name param.Opt[string] - if m.Name != "" { - name = openai.String(m.Name) - } - - res = openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Name: name, - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfString: textContent, - OfArrayOfContentParts: arrayOfContentParts, - }, - }, - } - - case RoleAssistant: - var name param.Opt[string] - if m.Name != "" { - name = openai.String(m.Name) - } - - res = openai.ChatCompletionMessageParamUnion{ - OfAssistant: &openai.ChatCompletionAssistantMessageParam{ - Name: name, - Content: openai.ChatCompletionAssistantMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } - - return []openai.ChatCompletionMessageParamUnion{res} -} - -func toolCallToChatCompletionMessages(t ToolCall) []openai.ChatCompletionMessageParamUnion { - return []openai.ChatCompletionMessageParamUnion{{ - OfAssistant: &openai.ChatCompletionAssistantMessageParam{ - ToolCalls: []openai.ChatCompletionMessageToolCallParam{ - { - ID: t.ID, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: t.FunctionCall.Name, - Arguments: t.FunctionCall.Arguments, - }, - }, - }, - }, - }} -} - -func toolCallResponseToChatCompletionMessages(t ToolCallResponse) []openai.ChatCompletionMessageParamUnion { - var refusal string - if t.Error != nil { - refusal = t.Error.Error() - } - - result := t.Result - if refusal != "" { - if result != "" { - result = fmt.Sprint(result) + " (error in execution: " + refusal + ")" - } else { - result = "error in execution:" + refusal - } - } - - return []openai.ChatCompletionMessageParamUnion{{ - OfTool: &openai.ChatCompletionToolMessageParam{ - ToolCallID: t.ID, - Content: openai.ChatCompletionToolMessageParamContentUnion{ - OfString: openai.String(fmt.Sprint(result)), - }, - }, - }} -} - -func responseChoiceToChatCompletionMessages(r ResponseChoice) []openai.ChatCompletionMessageParamUnion { - var as openai.ChatCompletionAssistantMessageParam - - if r.Name != "" { - as.Name = openai.String(r.Name) - } - if r.Refusal != "" { - as.Refusal = openai.String(r.Refusal) - } - - if r.Content != "" { - as.Content.OfString = openai.String(r.Content) - } - - for _, call := range r.Calls { - as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ - ID: call.ID, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: call.FunctionCall.Name, - Arguments: call.FunctionCall.Arguments, - }, - }) - } - return []openai.ChatCompletionMessageParamUnion{ - { - OfAssistant: &as, - }, - } -} diff --git a/openai_transcriber.go b/openai_transcriber.go deleted file mode 100644 index 9afe434..0000000 --- a/openai_transcriber.go +++ /dev/null @@ -1,219 +0,0 @@ -package llm - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" -) - -type openaiTranscriber struct { - key string - model string - baseUrl string -} - -var _ Transcriber = openaiTranscriber{} - -// OpenAITranscriber creates a transcriber backed by OpenAI's audio models. -// If model is empty, whisper-1 is used by default. -func OpenAITranscriber(key string, model string) Transcriber { - if strings.TrimSpace(model) == "" { - model = "whisper-1" - } - return openaiTranscriber{ - key: key, - model: model, - } -} - -func (o openaiTranscriber) Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) { - if len(wav) == 0 { - return Transcription{}, fmt.Errorf("wav data is empty") - } - - format := opts.ResponseFormat - if format == "" { - if strings.HasPrefix(o.model, "gpt-4o") { - format = TranscriptionResponseFormatJSON - } else { - format = TranscriptionResponseFormatVerboseJSON - } - } - - if format != TranscriptionResponseFormatJSON && format != TranscriptionResponseFormatVerboseJSON { - return Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output") - } - - if len(opts.TimestampGranularities) > 0 && format != TranscriptionResponseFormatVerboseJSON { - return Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json") - } - - params := openai.AudioTranscriptionNewParams{ - File: openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"), - Model: openai.AudioModel(o.model), - } - - if opts.Language != "" { - params.Language = openai.String(opts.Language) - } - if opts.Prompt != "" { - params.Prompt = openai.String(opts.Prompt) - } - if opts.Temperature != nil { - params.Temperature = openai.Float(*opts.Temperature) - } - - params.ResponseFormat = openai.AudioResponseFormat(format) - - if opts.IncludeLogprobs { - params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs} - } - - if len(opts.TimestampGranularities) > 0 { - for _, granularity := range opts.TimestampGranularities { - params.TimestampGranularities = append(params.TimestampGranularities, string(granularity)) - } - } - - clientOptions := []option.RequestOption{ - option.WithAPIKey(o.key), - } - if o.baseUrl != "" { - clientOptions = append(clientOptions, option.WithBaseURL(o.baseUrl)) - } - - client := openai.NewClient(clientOptions...) - resp, err := client.Audio.Transcriptions.New(ctx, params) - if err != nil { - return Transcription{}, fmt.Errorf("openai transcription failed: %w", err) - } - - return openaiTranscriptionToResult(o.model, resp), nil -} - -type openaiVerboseTranscription struct { - Text string `json:"text"` - Language string `json:"language"` - Duration float64 `json:"duration"` - Segments []openaiVerboseSegment `json:"segments"` - Words []openaiVerboseWord `json:"words"` -} - -type openaiVerboseSegment struct { - ID int `json:"id"` - Start float64 `json:"start"` - End float64 `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` - AvgLogprob *float64 `json:"avg_logprob"` - CompressionRatio *float64 `json:"compression_ratio"` - NoSpeechProb *float64 `json:"no_speech_prob"` - Words []openaiVerboseWord `json:"words"` -} - -type openaiVerboseWord struct { - Word string `json:"word"` - Start float64 `json:"start"` - End float64 `json:"end"` -} - -func openaiTranscriptionToResult(model string, resp *openai.Transcription) Transcription { - result := Transcription{ - Provider: "openai", - Model: model, - } - if resp == nil { - return result - } - - result.Text = resp.Text - result.RawJSON = resp.RawJSON() - - for _, logprob := range resp.Logprobs { - result.Logprobs = append(result.Logprobs, TranscriptionTokenLogprob{ - Token: logprob.Token, - Bytes: logprob.Bytes, - Logprob: logprob.Logprob, - }) - } - - if usage := openaiUsageToTranscriptionUsage(resp.Usage); usage.Type != "" { - result.Usage = usage - } - - if result.RawJSON == "" { - return result - } - - var verbose openaiVerboseTranscription - if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil { - return result - } - - if verbose.Text != "" { - result.Text = verbose.Text - } - result.Language = verbose.Language - result.DurationSeconds = verbose.Duration - - for _, seg := range verbose.Segments { - segment := TranscriptionSegment{ - ID: seg.ID, - Start: seg.Start, - End: seg.End, - Text: seg.Text, - Tokens: append([]int(nil), seg.Tokens...), - AvgLogprob: seg.AvgLogprob, - CompressionRatio: seg.CompressionRatio, - NoSpeechProb: seg.NoSpeechProb, - } - - for _, word := range seg.Words { - segment.Words = append(segment.Words, TranscriptionWord{ - Word: word.Word, - Start: word.Start, - End: word.End, - }) - } - - result.Segments = append(result.Segments, segment) - } - - for _, word := range verbose.Words { - result.Words = append(result.Words, TranscriptionWord{ - Word: word.Word, - Start: word.Start, - End: word.End, - }) - } - - return result -} - -func openaiUsageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) TranscriptionUsage { - switch usage.Type { - case "tokens": - tokens := usage.AsTokens() - return TranscriptionUsage{ - Type: usage.Type, - InputTokens: tokens.InputTokens, - OutputTokens: tokens.OutputTokens, - TotalTokens: tokens.TotalTokens, - AudioTokens: tokens.InputTokenDetails.AudioTokens, - TextTokens: tokens.InputTokenDetails.TextTokens, - } - case "duration": - duration := usage.AsDuration() - return TranscriptionUsage{ - Type: usage.Type, - Seconds: duration.Seconds, - } - default: - return TranscriptionUsage{} - } -} diff --git a/parse.go b/parse.go deleted file mode 100644 index a9614ff..0000000 --- a/parse.go +++ /dev/null @@ -1,50 +0,0 @@ -package llm - -import ( - "strings" -) - -// Providers are the allowed shortcuts in the providers, e.g.: if you set { "openai": OpenAI("key") } that'll allow -// for the "openai" provider to be used when parsed. -type Providers map[string]LLM - -// Parse will parse the provided input and attempt to return a LLM chat completion interface. -// Input should be in the provided format: -// - provider/modelname -// -// where provider is a key inside Providers, and the modelname being passed to the LLM interface's GetModel -func (providers Providers) Parse(input string) ChatCompletion { - sections := strings.Split(input, "/") - - var provider LLM - var ok bool - var modelVersion string - - if len(sections) < 2 { - // is there a default provider? - provider, ok = providers["default"] - if !ok { - panic("expected format: \"provider/model\" or provide a \"default\" provider to the Parse callback") - } - - modelVersion = sections[0] - } else { - provider, ok = providers[sections[0]] - modelVersion = sections[1] - } - - if !ok { - panic("expected format: \"provider/model\" or provide a \"default\" provider to the Parse callback") - } - - if provider == nil { - panic("unknown provider: " + sections[0]) - } - - res, err := provider.ModelVersion(modelVersion) - if err != nil { - panic(err) - } - - return res -} diff --git a/provider/anthropic/anthropic.go b/provider/anthropic/anthropic.go deleted file mode 100644 index 32e13e4..0000000 --- a/provider/anthropic/anthropic.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package anthropic provides the Anthropic LLM provider. -package anthropic - -import ( - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// New creates a new Anthropic LLM provider with the given API key. -func New(key string) llm.LLM { - return llm.Anthropic(key) -} diff --git a/provider/google/google.go b/provider/google/google.go deleted file mode 100644 index df5fe72..0000000 --- a/provider/google/google.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package google provides the Google LLM provider. -package google - -import ( - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// New creates a new Google LLM provider with the given API key. -func New(key string) llm.LLM { - return llm.Google(key) -} diff --git a/provider/openai/openai.go b/provider/openai/openai.go deleted file mode 100644 index 2c853f9..0000000 --- a/provider/openai/openai.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package openai provides the OpenAI LLM provider. -package openai - -import ( - llm "gitea.stevedudenhoeffer.com/steve/go-llm" -) - -// New creates a new OpenAI LLM provider with the given API key. -func New(key string) llm.LLM { - return llm.OpenAI(key) -} diff --git a/request.go b/request.go deleted file mode 100644 index 43e430f..0000000 --- a/request.go +++ /dev/null @@ -1,51 +0,0 @@ -package llm - -// Input is the interface for conversation inputs. -// Types that implement this interface can be part of a conversation: -// Message, ToolCall, ToolCallResponse, and ResponseChoice. -type Input interface { - // isInput is a marker method to ensure only valid types implement this interface. - isInput() -} - -// Implement Input interface for all valid input types. -func (Message) isInput() {} -func (ToolCall) isInput() {} -func (ToolCallResponse) isInput() {} -func (ResponseChoice) isInput() {} - -// Request represents a request to a language model. -type Request struct { - Conversation []Input - Messages []Message - Toolbox ToolBox - Temperature *float64 -} - -// NextRequest will take the current request's conversation, messages, the response, and any tool results, and -// return a new request with the conversation updated to include the response and tool results. -func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request { - var res Request - - res.Toolbox = req.Toolbox - res.Temperature = req.Temperature - - res.Conversation = make([]Input, len(req.Conversation)) - copy(res.Conversation, req.Conversation) - - // now for every input message, convert those to an Input to add to the conversation - for _, msg := range req.Messages { - res.Conversation = append(res.Conversation, msg) - } - - if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 { - res.Conversation = append(res.Conversation, resp) - } - - // if there are tool results, then we need to add those to the conversation - for _, result := range toolResults { - res.Conversation = append(res.Conversation, result) - } - - return res -} diff --git a/response.go b/response.go deleted file mode 100644 index e7043b8..0000000 --- a/response.go +++ /dev/null @@ -1,52 +0,0 @@ -package llm - -// ResponseChoice represents a single choice in a response. -type ResponseChoice struct { - Index int - Role Role - Content string - Refusal string - Name string - Calls []ToolCall -} - -func (r ResponseChoice) toRaw() map[string]any { - res := map[string]any{ - "index": r.Index, - "role": r.Role, - "content": r.Content, - "refusal": r.Refusal, - "name": r.Name, - } - - calls := make([]map[string]any, 0, len(r.Calls)) - for _, call := range r.Calls { - calls = append(calls, call.toRaw()) - } - - res["tool_calls"] = calls - - return res -} - -func (r ResponseChoice) toInput() []Input { - var res []Input - - for _, call := range r.Calls { - res = append(res, call) - } - - if r.Content != "" || r.Refusal != "" { - res = append(res, Message{ - Role: RoleAssistant, - Text: r.Content, - }) - } - - return res -} - -// Response represents a response from a language model. -type Response struct { - Choices []ResponseChoice -} diff --git a/schema/GetType.go b/schema/GetType.go deleted file mode 100644 index 342dea2..0000000 --- a/schema/GetType.go +++ /dev/null @@ -1,142 +0,0 @@ -package schema - -import ( - "reflect" - "strings" -) - -// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that -// can be used to generate a json schema and build an object from a parsed json object. -func GetType(a any) Type { - t := reflect.TypeOf(a) - - if t.Kind() != reflect.Struct { - panic("GetType expects a struct") - } - - return getObject(t) -} - -func getFromType(t reflect.Type, b basic) Type { - if t.Kind() == reflect.Ptr { - t = t.Elem() - b.required = false - } - - switch t.Kind() { - case reflect.String: - b.DataType = TypeString - b.typeName = "string" - return b - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - b.DataType = TypeInteger - b.typeName = "integer" - return b - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - b.DataType = TypeInteger - b.typeName = "integer" - return b - - case reflect.Float32, reflect.Float64: - b.DataType = TypeNumber - b.typeName = "number" - return b - - case reflect.Bool: - b.DataType = TypeBoolean - b.typeName = "boolean" - return b - - case reflect.Struct: - o := getObject(t) - - o.basic.required = b.required - o.basic.index = b.index - o.basic.description = b.description - - return o - - case reflect.Slice: - return getArray(t) - - default: - panic("unhandled default case for " + t.Kind().String() + " in getFromType") - } -} - -func getField(f reflect.StructField, index int) Type { - b := basic{ - index: index, - required: true, - description: "", - } - - t := f.Type - - // if the tag "description" is set, use that as the description - if desc, ok := f.Tag.Lookup("description"); ok { - b.description = desc - } - - // now if the tag "enum" is set, we need to create an enum type - if v, ok := f.Tag.Lookup("enum"); ok { - vals := strings.Split(v, ",") - - for i := 0; i < len(vals); i++ { - vals[i] = strings.TrimSpace(vals[i]) - - if vals[i] == "" { - vals = append(vals[:i], vals[i+1:]...) - } - } - - b.DataType = TypeString - b.typeName = "string" - return enum{ - basic: b, - values: vals, - } - - } - - return getFromType(t, b) -} - -func getObject(t reflect.Type) Object { - fields := make(map[string]Type, t.NumField()) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - if field.Anonymous { - // if the field is anonymous, we need to get the fields of the anonymous struct - // and add them to the object - anon := getObject(field.Type) - for k, v := range anon.fields { - fields[k] = v - } - continue - } else { - fields[field.Name] = getField(field, i) - } - } - - return Object{ - basic: basic{DataType: TypeObject, typeName: "object"}, - fields: fields, - } -} - -func getArray(t reflect.Type) array { - res := array{ - basic: basic{ - DataType: TypeArray, - typeName: "array", - }, - } - - res.items = getFromType(t.Elem(), basic{}) - - return res -} diff --git a/schema/array.go b/schema/array.go deleted file mode 100644 index c18c75c..0000000 --- a/schema/array.go +++ /dev/null @@ -1,77 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -type array struct { - basic - - // items is the schema of the items in the array - items Type -} - -func (a array) OpenAIParameters() openai.FunctionParameters { - return openai.FunctionParameters{ - "type": "array", - "description": a.Description(), - "items": a.items.OpenAIParameters(), - } -} - -func (a array) GoogleParameters() *genai.Schema { - return &genai.Schema{ - Type: genai.TypeArray, - Description: a.Description(), - Items: a.items.GoogleParameters(), - } -} - -func (a array) AnthropicInputSchema() map[string]any { - return map[string]any{ - "type": "array", - "description": a.Description(), - "items": a.items.AnthropicInputSchema(), - } -} - -func (a array) FromAny(val any) (reflect.Value, error) { - v := reflect.ValueOf(val) - - // first realize we may have a pointer to a slice if this type is not required - if !a.required && v.Kind() == reflect.Ptr { - v = v.Elem() - } - - if v.Kind() != reflect.Slice { - return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String()) - } - - // if the slice is nil, we can just return it - if v.IsNil() { - return v, nil - } - - // if the slice is not nil, we need to convert each item - items := make([]reflect.Value, v.Len()) - for i := 0; i < v.Len(); i++ { - item, err := a.items.FromAny(v.Index(i).Interface()) - if err != nil { - return reflect.Value{}, err - } - items[i] = item - } - - return reflect.ValueOf(items), nil -} - -func (a array) SetValue(obj reflect.Value, val reflect.Value) { - if !a.required { - val = val.Addr() - } - obj.Field(a.index).Set(val) -} diff --git a/schema/basic.go b/schema/basic.go deleted file mode 100644 index a38f5a4..0000000 --- a/schema/basic.go +++ /dev/null @@ -1,165 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - "strconv" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -// just enforcing that basic implements Type -var _ Type = basic{} - -type DataType string - -const ( - TypeString DataType = "string" - TypeInteger DataType = "integer" - TypeNumber DataType = "number" - TypeBoolean DataType = "boolean" - TypeObject DataType = "object" - TypeArray DataType = "array" -) - -type basic struct { - DataType - typeName string - - // index is the position of the parameter in the StructField of the function's parameter struct - index int - - // required is a flag that indicates whether the parameter is required in the function's parameter struct. - // this is inferred by if the parameter is a pointer type or not. - required bool - - // description is a llm-readable description of the parameter passed to openai - description string -} - -func (b basic) OpenAIParameters() openai.FunctionParameters { - return openai.FunctionParameters{ - "type": b.typeName, - "description": b.description, - } -} - -func (b basic) GoogleParameters() *genai.Schema { - var t = genai.TypeUnspecified - - switch b.DataType { - case TypeString: - t = genai.TypeString - case TypeInteger: - t = genai.TypeInteger - case TypeNumber: - t = genai.TypeNumber - case TypeBoolean: - t = genai.TypeBoolean - case TypeObject: - t = genai.TypeObject - case TypeArray: - t = genai.TypeArray - default: - t = genai.TypeUnspecified - } - return &genai.Schema{ - Type: t, - Description: b.description, - } -} - -func (b basic) AnthropicInputSchema() map[string]any { - var t = "string" - - switch b.DataType { - case TypeString: - t = "string" - case TypeInteger: - t = "integer" - case TypeNumber: - t = "number" - case TypeBoolean: - t = "boolean" - case TypeObject: - t = "object" - case TypeArray: - t = "array" - default: - t = "unknown" - } - - return map[string]any{ - "type": t, - "description": b.description, - } -} - -func (b basic) Required() bool { - return b.required -} - -func (b basic) Description() string { - return b.description -} - -func (b basic) FromAny(val any) (reflect.Value, error) { - v := reflect.ValueOf(val) - - switch b.DataType { - case TypeString: - var val = v.String() - - return reflect.ValueOf(val), nil - - case TypeInteger: - if v.Kind() == reflect.Float64 { - return v.Convert(reflect.TypeOf(int(0))), nil - } else if v.Kind() != reflect.Int { - return reflect.Value{}, errors.New("expected int, got " + v.Kind().String()) - } else { - return v, nil - } - - case TypeNumber: - if v.Kind() == reflect.Float64 { - return v.Convert(reflect.TypeOf(float64(0))), nil - } else if v.Kind() != reflect.Float64 { - return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String()) - } else { - return v, nil - } - - case TypeBoolean: - if v.Kind() == reflect.Bool { - return v, nil - } else if v.Kind() == reflect.String { - b, err := strconv.ParseBool(v.String()) - if err != nil { - return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) - } - return reflect.ValueOf(b), nil - } else { - return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String()) - } - - default: - return reflect.Value{}, errors.New("unknown type") - } -} - -func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) { - // if this basic type is not required that means it's a pointer type - // so we need to create a new value of the type of the pointer - if !b.required { - vv := reflect.New(obj.Field(b.index).Type().Elem()) - - // and then set the value of the pointer to the new value - vv.Elem().Set(val) - - obj.Field(b.index).Set(vv) - return - } - obj.Field(b.index).Set(val) -} diff --git a/schema/enum.go b/schema/enum.go deleted file mode 100644 index 1b473a7..0000000 --- a/schema/enum.go +++ /dev/null @@ -1,61 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - "slices" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -type enum struct { - basic - - values []string -} - -func (e enum) FunctionParameters() openai.FunctionParameters { - return openai.FunctionParameters{ - "type": "string", - "description": e.Description(), - "enum": e.values, - } -} - -func (e enum) GoogleParameters() *genai.Schema { - return &genai.Schema{ - Type: genai.TypeString, - Description: e.Description(), - Enum: e.values, - } -} - -func (e enum) AnthropicInputSchema() map[string]any { - return map[string]any{ - "type": "string", - "description": e.Description(), - "enum": e.values, - } -} - -func (e enum) FromAny(val any) (reflect.Value, error) { - v := reflect.ValueOf(val) - if v.Kind() != reflect.String { - return reflect.Value{}, errors.New("expected string, got " + v.Kind().String()) - } - - s := v.String() - if !slices.Contains(e.values, s) { - return reflect.Value{}, errors.New("value " + s + " not in enum") - } - - return v, nil -} - -func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) { - if !e.required { - val = val.Addr() - } - obj.Field(e.index).Set(val) -} diff --git a/schema/object.go b/schema/object.go deleted file mode 100644 index 053baa8..0000000 --- a/schema/object.go +++ /dev/null @@ -1,169 +0,0 @@ -package schema - -import ( - "errors" - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -const ( - // SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent - // collisions with the fields in the struct. - SyntheticFieldPrefix = "__" -) - -type Object struct { - basic - - ref reflect.Type - - fields map[string]Type - - // syntheticFields are fields that are not in the struct but are generated by a system. - synetheticFields map[string]Type -} - -func (o Object) WithSyntheticField(name string, description string) Object { - if o.synetheticFields == nil { - o.synetheticFields = map[string]Type{} - } - - o.synetheticFields[name] = basic{ - DataType: TypeString, - typeName: "string", - index: -1, - required: false, - description: description, - } - - return o -} - -func (o Object) SyntheticFields() map[string]Type { - return o.synetheticFields -} - -func (o Object) OpenAIParameters() openai.FunctionParameters { - var properties = map[string]openai.FunctionParameters{} - var required []string - for k, v := range o.fields { - properties[k] = v.OpenAIParameters() - if v.Required() { - required = append(required, k) - } - } - - for k, v := range o.synetheticFields { - properties[SyntheticFieldPrefix+k] = v.OpenAIParameters() - if v.Required() { - required = append(required, SyntheticFieldPrefix+k) - } - } - - var res = openai.FunctionParameters{ - "type": "object", - "description": o.Description(), - "properties": properties, - } - - if len(required) > 0 { - res["required"] = required - } - - return res -} - -func (o Object) GoogleParameters() *genai.Schema { - var properties = map[string]*genai.Schema{} - var required []string - for k, v := range o.fields { - properties[k] = v.GoogleParameters() - if v.Required() { - required = append(required, k) - } - } - - var res = &genai.Schema{ - Type: genai.TypeObject, - Description: o.Description(), - Properties: properties, - } - - if len(required) > 0 { - res.Required = required - } - - return res -} - -func (o Object) AnthropicInputSchema() map[string]any { - var properties = map[string]any{} - var required []string - for k, v := range o.fields { - properties[k] = v.AnthropicInputSchema() - if v.Required() { - required = append(required, k) - } - } - - var res = map[string]any{ - "type": "object", - "description": o.Description(), - "properties": properties, - } - - if len(required) > 0 { - res["required"] = required - } - - return res -} - -// FromAny converts the value from any to the correct type, returning the value, and an error if any -func (o Object) FromAny(val any) (reflect.Value, error) { - // if the value is nil, we can't do anything - if val == nil { - return reflect.Value{}, nil - } - - // now make a new object of the type we're trying to parse - obj := reflect.New(o.ref).Elem() - - // now we need to iterate over the fields and set the values - for k, v := range o.fields { - // get the field by name - field := obj.FieldByName(k) - if !field.IsValid() { - return reflect.Value{}, errors.New("field " + k + " not found") - } - - // get the value from the map - val2, ok := val.(map[string]interface{})[k] - if !ok { - return reflect.Value{}, errors.New("field " + k + " not found in map") - } - - // now we need to convert the value to the correct type - val3, err := v.FromAny(val2) - if err != nil { - return reflect.Value{}, err - } - - // now we need to set the value on the field - v.SetValueOnField(field, val3) - - } - - return obj, nil -} - -func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) { - // if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value - if !o.required { - val = val.Addr() - } - - obj.Field(o.index).Set(val) -} diff --git a/schema/raw.go b/schema/raw.go deleted file mode 100644 index b4dcb27..0000000 --- a/schema/raw.go +++ /dev/null @@ -1,134 +0,0 @@ -package schema - -import ( - "encoding/json" - "fmt" - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -// Raw represents a raw JSON schema that is passed through directly. -// This is used for MCP tools where we receive the schema from the server. -type Raw struct { - schema map[string]any -} - -// NewRaw creates a new Raw schema from a map. -func NewRaw(schema map[string]any) Raw { - if schema == nil { - schema = map[string]any{ - "type": "object", - "properties": map[string]any{}, - } - } - return Raw{schema: schema} -} - -// NewRawFromJSON creates a new Raw schema from JSON bytes. -func NewRawFromJSON(data []byte) (Raw, error) { - var schema map[string]any - if err := json.Unmarshal(data, &schema); err != nil { - return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err) - } - return NewRaw(schema), nil -} - -func (r Raw) OpenAIParameters() openai.FunctionParameters { - return openai.FunctionParameters(r.schema) -} - -func (r Raw) GoogleParameters() *genai.Schema { - return mapToGenaiSchema(r.schema) -} - -func (r Raw) AnthropicInputSchema() map[string]any { - return r.schema -} - -func (r Raw) Required() bool { - return false -} - -func (r Raw) Description() string { - if desc, ok := r.schema["description"].(string); ok { - return desc - } - return "" -} - -func (r Raw) FromAny(val any) (reflect.Value, error) { - return reflect.ValueOf(val), nil -} - -func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) { - // No-op for raw schemas -} - -// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema -func mapToGenaiSchema(m map[string]any) *genai.Schema { - if m == nil { - return nil - } - - schema := &genai.Schema{} - - // Type - if t, ok := m["type"].(string); ok { - switch t { - case "string": - schema.Type = genai.TypeString - case "number": - schema.Type = genai.TypeNumber - case "integer": - schema.Type = genai.TypeInteger - case "boolean": - schema.Type = genai.TypeBoolean - case "array": - schema.Type = genai.TypeArray - case "object": - schema.Type = genai.TypeObject - } - } - - // Description - if desc, ok := m["description"].(string); ok { - schema.Description = desc - } - - // Enum - if enum, ok := m["enum"].([]any); ok { - for _, e := range enum { - if s, ok := e.(string); ok { - schema.Enum = append(schema.Enum, s) - } - } - } - - // Properties (for objects) - if props, ok := m["properties"].(map[string]any); ok { - schema.Properties = make(map[string]*genai.Schema) - for k, v := range props { - if vm, ok := v.(map[string]any); ok { - schema.Properties[k] = mapToGenaiSchema(vm) - } - } - } - - // Required - if req, ok := m["required"].([]any); ok { - for _, r := range req { - if s, ok := r.(string); ok { - schema.Required = append(schema.Required, s) - } - } - } - - // Items (for arrays) - if items, ok := m["items"].(map[string]any); ok { - schema.Items = mapToGenaiSchema(items) - } - - return schema -} diff --git a/schema/type.go b/schema/type.go deleted file mode 100644 index 318f9b6..0000000 --- a/schema/type.go +++ /dev/null @@ -1,23 +0,0 @@ -package schema - -import ( - "reflect" - - "github.com/openai/openai-go" - "google.golang.org/genai" -) - -type Type interface { - OpenAIParameters() openai.FunctionParameters - GoogleParameters() *genai.Schema - AnthropicInputSchema() map[string]any - - //SchemaType() jsonschema.DataType - //Definition() jsonschema.Definition - - Required() bool - Description() string - - FromAny(any) (reflect.Value, error) - SetValueOnField(obj reflect.Value, val reflect.Value) -} diff --git a/toolbox.go b/toolbox.go deleted file mode 100644 index 10cb892..0000000 --- a/toolbox.go +++ /dev/null @@ -1,174 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "errors" - "fmt" -) - -// ToolBox is a collection of tools that OpenAI can use to execute functions. -// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with -// the correct parameters. -type ToolBox struct { - functions map[string]Function - mcpServers map[string]*MCPServer // tool name -> MCP server that provides it - dontRequireTool bool -} - -func NewToolBox(fns ...Function) ToolBox { - res := ToolBox{ - functions: map[string]Function{}, - } - - for _, f := range fns { - res.functions[f.Name] = f - } - - return res -} - -func (t ToolBox) Functions() []Function { - var res []Function - - for _, f := range t.functions { - res = append(res, f) - } - - return res -} - -func (t ToolBox) WithFunction(f Function) ToolBox { - t.functions[f.Name] = f - - return t -} - -func (t ToolBox) WithFunctions(fns ...Function) ToolBox { - for _, f := range fns { - t.functions[f.Name] = f - } - - return t -} - -func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox { - for k, v := range t.functions { - t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions) - } - - return t -} - -func (t ToolBox) ForEachFunction(fn func(f Function)) { - for _, f := range t.functions { - fn(f) - } -} - -func (t ToolBox) WithFunctionRemoved(name string) ToolBox { - delete(t.functions, name) - return t -} - -func (t ToolBox) WithRequireTool(val bool) ToolBox { - t.dontRequireTool = !val - return t -} - -func (t ToolBox) RequiresTool() bool { - return !t.dontRequireTool && len(t.functions) > 0 -} - -func (t ToolBox) ToToolChoice() any { - if len(t.functions) == 0 { - return nil - } - - return "required" -} - -var ( - ErrFunctionNotFound = errors.New("function not found") -) - -func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) { - // Check if this is an MCP tool - if server, ok := t.mcpServers[functionName]; ok { - var args map[string]any - if params != "" { - if err := json.Unmarshal([]byte(params), &args); err != nil { - return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err) - } - } - return server.CallTool(ctx, functionName, args) - } - - // Regular function - f, ok := t.functions[functionName] - - if !ok { - return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName)) - } - - return f.Execute(ctx, params) -} - -func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) { - return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments) -} - -func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string { - val := ctx.Value("syntheticParameters") - - if val == nil { - return nil - } - - syntheticParameters, ok := val.(map[string]string) - if !ok { - return nil - } - - return syntheticParameters -} - -// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished. -// OnNewFunction is called when a new function is created -// OnFunctionFinished is called when a function is finished -func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) { - var res []ToolCallResponse - - for _, call := range toolCalls { - ctx := ctx.WithToolCall(&call) - if call.FunctionCall.Name == "" { - return nil, newError(ErrFunctionNotFound, errors.New("function name is empty")) - } - - var arg any - if OnNewFunction != nil { - var err error - arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments) - - if err != nil { - return nil, newError(ErrFunctionNotFound, err) - } - } - out, err := t.Execute(ctx, call) - - if OnFunctionFinished != nil { - err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg) - if err != nil { - return nil, newError(ErrFunctionNotFound, err) - } - } - - res = append(res, ToolCallResponse{ - ID: call.ID, - Result: out, - Error: err, - }) - } - - return res, nil -} diff --git a/transcriber.go b/transcriber.go deleted file mode 100644 index 9cef7f3..0000000 --- a/transcriber.go +++ /dev/null @@ -1,145 +0,0 @@ -package llm - -import ( - "context" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" -) - -// Transcriber abstracts a speech-to-text model implementation. -type Transcriber interface { - Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) -} - -// TranscriptionResponseFormat controls the output format requested from a transcriber. -type TranscriptionResponseFormat string - -const ( - TranscriptionResponseFormatJSON TranscriptionResponseFormat = "json" - TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json" - TranscriptionResponseFormatText TranscriptionResponseFormat = "text" - TranscriptionResponseFormatSRT TranscriptionResponseFormat = "srt" - TranscriptionResponseFormatVTT TranscriptionResponseFormat = "vtt" -) - -// TranscriptionTimestampGranularity defines the requested timestamp detail. -type TranscriptionTimestampGranularity string - -const ( - TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" - TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" -) - -// TranscriptionOptions configures transcription behavior. -type TranscriptionOptions struct { - Language string - Prompt string - Temperature *float64 - ResponseFormat TranscriptionResponseFormat - TimestampGranularities []TranscriptionTimestampGranularity - IncludeLogprobs bool -} - -// Transcription captures a normalized transcription result. -type Transcription struct { - Provider string - Model string - Text string - Language string - DurationSeconds float64 - Segments []TranscriptionSegment - Words []TranscriptionWord - Logprobs []TranscriptionTokenLogprob - Usage TranscriptionUsage - RawJSON string -} - -// TranscriptionSegment provides a coarse time-sliced transcription segment. -type TranscriptionSegment struct { - ID int - Start float64 - End float64 - Text string - Tokens []int - AvgLogprob *float64 - CompressionRatio *float64 - NoSpeechProb *float64 - Words []TranscriptionWord -} - -// TranscriptionWord provides a word-level timestamp. -type TranscriptionWord struct { - Word string - Start float64 - End float64 - Confidence *float64 -} - -// TranscriptionTokenLogprob captures token-level log probability details. -type TranscriptionTokenLogprob struct { - Token string - Bytes []float64 - Logprob float64 -} - -// TranscriptionUsage captures token or duration usage details. -type TranscriptionUsage struct { - Type string - InputTokens int64 - OutputTokens int64 - TotalTokens int64 - AudioTokens int64 - TextTokens int64 - Seconds float64 -} - -// TranscribeFile converts an audio file to WAV and transcribes it. -func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) { - if transcriber == nil { - return Transcription{}, fmt.Errorf("transcriber is nil") - } - - wav, err := audioFileToWav(ctx, filename) - if err != nil { - return Transcription{}, err - } - - return transcriber.Transcribe(ctx, wav, opts) -} - -func audioFileToWav(ctx context.Context, filename string) ([]byte, error) { - if filename == "" { - return nil, fmt.Errorf("filename is empty") - } - - if strings.EqualFold(filepath.Ext(filename), ".wav") { - data, err := os.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("read wav file: %w", err) - } - return data, nil - } - - tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav") - if err != nil { - return nil, fmt.Errorf("create temp wav file: %w", err) - } - tempPath := tempFile.Name() - _ = tempFile.Close() - defer os.Remove(tempPath) - - cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath) - if output, err := cmd.CombinedOutput(); err != nil { - return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output))) - } - - data, err := os.ReadFile(tempPath) - if err != nil { - return nil, fmt.Errorf("read converted wav file: %w", err) - } - - return data, nil -} diff --git a/v2/cmd/llm/.env.example b/v2/cmd/llm/.env.example new file mode 100644 index 0000000..b8369fd --- /dev/null +++ b/v2/cmd/llm/.env.example @@ -0,0 +1,27 @@ +# go-llm CLI environment variables +# Copy this file to .env and fill in the keys for providers you use. + +# OpenAI API Key (https://platform.openai.com/api-keys) +OPENAI_API_KEY= + +# Anthropic API Key (https://console.anthropic.com/settings/keys) +ANTHROPIC_API_KEY= + +# Google AI API Key (https://aistudio.google.com/apikey) +GOOGLE_API_KEY= + +# DeepSeek API Key (https://platform.deepseek.com) +DEEPSEEK_API_KEY= + +# Moonshot / Kimi API Key (https://platform.moonshot.ai) +MOONSHOT_API_KEY= + +# xAI / Grok API Key (https://x.ai/api) +XAI_API_KEY= + +# Groq API Key (https://console.groq.com/keys) +GROQ_API_KEY= + +# Ollama runs locally with no API key required. +# Override the endpoint if you're not using localhost:11434. +# OLLAMA_BASE_URL=http://localhost:11434/v1 diff --git a/v2/cmd/llm/commands.go b/v2/cmd/llm/commands.go new file mode 100644 index 0000000..03cefb5 --- /dev/null +++ b/v2/cmd/llm/commands.go @@ -0,0 +1,136 @@ +package main + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + + tea "github.com/charmbracelet/bubbletea" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// Message types for async operations. + +// ChatResponseMsg contains the response from a chat completion. +type ChatResponseMsg struct { + Response llm.Response + Err error +} + +// ToolExecutionMsg contains results from executing tool calls, one Message +// (RoleTool) per ToolCall, in the same order. +type ToolExecutionMsg struct { + Results []llm.Message + Err error +} + +// ImageLoadedMsg contains a loaded image. +type ImageLoadedMsg struct { + Image llm.Image + Err error +} + +// sendChatRequest sends a completion request with the current conversation, +// returning a ChatResponseMsg tea.Msg when the provider responds. +func sendChatRequest(model *llm.Model, messages []llm.Message, toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) tea.Cmd { + return func() tea.Msg { + opts := buildOpts(toolbox, toolsEnabled, temperature) + resp, err := model.Complete(context.Background(), messages, opts...) + return ChatResponseMsg{Response: resp, Err: err} + } +} + +// executeTools runs each tool call via the toolbox and returns ToolExecutionMsg +// with one RoleTool Message per call, in the same order. +func executeTools(toolbox *llm.ToolBox, calls []llm.ToolCall) tea.Cmd { + return func() tea.Msg { + ctx := context.Background() + results, err := toolbox.ExecuteAll(ctx, calls) + return ToolExecutionMsg{Results: results, Err: err} + } +} + +// buildOpts constructs RequestOptions from the current CLI state. +func buildOpts(toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) []llm.RequestOption { + var opts []llm.RequestOption + if toolsEnabled && toolbox != nil && len(toolbox.AllTools()) > 0 { + opts = append(opts, llm.WithTools(toolbox)) + } + if temperature != nil { + opts = append(opts, llm.WithTemperature(*temperature)) + } + return opts +} + +// loadImageFromPath loads an image from a file path. +func loadImageFromPath(path string) tea.Cmd { + return func() tea.Msg { + path = strings.TrimSpace(path) + path = strings.Trim(path, "\"'") + + data, err := os.ReadFile(path) + if err != nil { + return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)} + } + + contentType := http.DetectContentType(data) + if !strings.HasPrefix(contentType, "image/") { + return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)} + } + + return ImageLoadedMsg{ + Image: llm.Image{ + Base64: base64.StdEncoding.EncodeToString(data), + ContentType: contentType, + }, + } + } +} + +// loadImageFromURL loads an image from a URL (kept as URL, not fetched). +func loadImageFromURL(url string) tea.Cmd { + return func() tea.Msg { + return ImageLoadedMsg{Image: llm.Image{URL: strings.TrimSpace(url)}} + } +} + +// loadImageFromBase64 loads an image from base64 data (raw or data: URL). +func loadImageFromBase64(data string) tea.Cmd { + return func() tea.Msg { + data = strings.TrimSpace(data) + + if strings.HasPrefix(data, "data:") { + parts := strings.SplitN(data, ",", 2) + if len(parts) != 2 { + return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")} + } + mediaType := strings.TrimPrefix(parts[0], "data:") + mediaType = strings.TrimSuffix(mediaType, ";base64") + return ImageLoadedMsg{ + Image: llm.Image{ + Base64: parts[1], + ContentType: mediaType, + }, + } + } + + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)} + } + contentType := http.DetectContentType(decoded) + if !strings.HasPrefix(contentType, "image/") { + return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)} + } + return ImageLoadedMsg{ + Image: llm.Image{ + Base64: data, + ContentType: contentType, + }, + } + } +} diff --git a/cmd/llm/main.go b/v2/cmd/llm/main.go similarity index 100% rename from cmd/llm/main.go rename to v2/cmd/llm/main.go diff --git a/cmd/llm/model.go b/v2/cmd/llm/model.go similarity index 55% rename from cmd/llm/model.go rename to v2/cmd/llm/model.go index 282a0b5..6bcfa8c 100644 --- a/cmd/llm/model.go +++ b/v2/cmd/llm/model.go @@ -7,10 +7,10 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" - llm "gitea.stevedudenhoeffer.com/steve/go-llm" + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" ) -// State represents the current view/screen of the application +// State represents the current view/screen of the application. type State int const ( @@ -23,43 +23,42 @@ const ( StateAPIKeyInput ) -// DisplayMessage represents a message for display in the UI +// DisplayMessage represents a message for display in the UI. type DisplayMessage struct { Role llm.Role Content string Images int // number of images attached } -// ProviderInfo contains information about a provider -type ProviderInfo struct { - Name string - EnvVar string - Models []string +// ProviderEntry is a CLI-local view of a registered provider, enriched with +// UI state (which model is currently chosen, whether we have a key, etc.). +type ProviderEntry struct { + Info llm.ProviderInfo HasAPIKey bool ModelIndex int } -// Model is the main Bubble Tea model +// Model is the main Bubble Tea model. type Model struct { // State state State previousState State // Provider - provider llm.LLM + client *llm.Client + chat *llm.Model providerName string - chat llm.ChatCompletion modelName string apiKeys map[string]string - providers []ProviderInfo + providers []ProviderEntry providerIndex int // Conversation - conversation []llm.Input + conversation []llm.Message messages []DisplayMessage // Tools - toolbox llm.ToolBox + toolbox *llm.ToolBox toolsEnabled bool // Settings @@ -90,7 +89,7 @@ type Model struct { apiKeyInput textinput.Model } -// InitialModel creates and returns the initial model +// InitialModel creates and returns the initial model. func InitialModel() Model { ti := textinput.New() ti.Placeholder = "Type your message..." @@ -104,60 +103,21 @@ func InitialModel() Model { aki.Width = 60 aki.EchoMode = textinput.EchoPassword - // Initialize providers with environment variable checks - providers := []ProviderInfo{ - { - Name: "OpenAI", - EnvVar: "OPENAI_API_KEY", - Models: []string{ - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-4o", - "gpt-4o-mini", - "gpt-4-turbo", - "gpt-3.5-turbo", - "o1", - "o1-mini", - "o1-preview", - "o3-mini", - }, - }, - { - Name: "Anthropic", - EnvVar: "ANTHROPIC_API_KEY", - Models: []string{ - "claude-sonnet-4-20250514", - "claude-opus-4-20250514", - "claude-3-7-sonnet-20250219", - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - }, - }, - { - Name: "Google", - EnvVar: "GOOGLE_API_KEY", - Models: []string{ - "gemini-2.0-flash", - "gemini-2.0-flash-lite", - "gemini-1.5-pro", - "gemini-1.5-flash", - "gemini-1.5-flash-8b", - "gemini-1.0-pro", - }, - }, - } - - // Check for API keys in environment + // Build provider list from the go-llm registry. + registry := llm.Providers() + providers := make([]ProviderEntry, len(registry)) apiKeys := make(map[string]string) - for i := range providers { - if key := os.Getenv(providers[i].EnvVar); key != "" { - apiKeys[providers[i].Name] = key - providers[i].HasAPIKey = true + + for i, info := range registry { + entry := ProviderEntry{Info: info} + if info.EnvKey == "" { + // Key-less provider (e.g., Ollama). + entry.HasAPIKey = true + } else if key := os.Getenv(info.EnvKey); key != "" { + apiKeys[info.Name] = key + entry.HasAPIKey = true } + providers[i] = entry } m := Model{ @@ -170,97 +130,87 @@ func InitialModel() Model { toolbox: createDemoToolbox(), toolsEnabled: false, messages: []DisplayMessage{}, - conversation: []llm.Input{}, + conversation: []llm.Message{}, } - // Build list items for provider selection + // Build list items for provider selection. m.listItems = make([]string, len(providers)) for i, p := range providers { status := " (no key)" if p.HasAPIKey { status = " (ready)" + if p.Info.EnvKey == "" { + status = " (local)" + } } - m.listItems[i] = p.Name + status + m.listItems[i] = p.Info.DisplayName + status } return m } -// Init initializes the model +// Init initializes the model. func (m Model) Init() tea.Cmd { return textinput.Blink } -// selectProvider sets up the selected provider +// selectProvider sets up the selected provider. func (m *Model) selectProvider(index int) error { if index < 0 || index >= len(m.providers) { return nil } p := m.providers[index] - key, ok := m.apiKeys[p.Name] - if !ok || key == "" { + key := m.apiKeys[p.Info.Name] // empty for key-less providers like Ollama + + if p.Info.EnvKey != "" && key == "" { return nil } - m.providerName = p.Name + m.providerName = p.Info.DisplayName m.providerIndex = index + m.client = p.Info.New(key) - switch p.Name { - case "OpenAI": - m.provider = llm.OpenAI(key) - case "Anthropic": - m.provider = llm.Anthropic(key) - case "Google": - m.provider = llm.Google(key) - } - - // Select default model - if len(p.Models) > 0 { + // Select default model. + if len(p.Info.Models) > 0 { return m.selectModel(p.ModelIndex) } return nil } -// selectModel sets the current model +// selectModel sets the current model. func (m *Model) selectModel(index int) error { - if m.provider == nil { + if m.client == nil { return nil } p := m.providers[m.providerIndex] - if index < 0 || index >= len(p.Models) { + if index < 0 || index >= len(p.Info.Models) { return nil } - modelName := p.Models[index] - chat, err := m.provider.ModelVersion(modelName) - if err != nil { - return err - } - - m.chat = chat + modelName := p.Info.Models[index] + m.chat = m.client.Model(modelName) m.modelName = modelName m.providers[m.providerIndex].ModelIndex = index return nil } -// newConversation resets the conversation +// newConversation resets the conversation. func (m *Model) newConversation() { - m.conversation = []llm.Input{} + m.conversation = []llm.Message{} m.messages = []DisplayMessage{} m.pendingImages = []llm.Image{} m.err = nil } -// addUserMessage adds a user message to the conversation +// addUserMessage adds a user message to the conversation. func (m *Model) addUserMessage(text string, images []llm.Image) { msg := llm.Message{ - Role: llm.RoleUser, - Text: text, - Images: images, + Role: llm.RoleUser, + Content: llm.Content{Text: text, Images: images}, } m.conversation = append(m.conversation, msg) m.messages = append(m.messages, DisplayMessage{ @@ -270,7 +220,7 @@ func (m *Model) addUserMessage(text string, images []llm.Image) { }) } -// addAssistantMessage adds an assistant message to the conversation +// addAssistantMessage adds an assistant message to the conversation display. func (m *Model) addAssistantMessage(content string) { m.messages = append(m.messages, DisplayMessage{ Role: llm.RoleAssistant, @@ -278,7 +228,7 @@ func (m *Model) addAssistantMessage(content string) { }) } -// addToolCallMessage adds a tool call message to display +// addToolCallMessage adds a tool call message to display. func (m *Model) addToolCallMessage(name string, args string) { m.messages = append(m.messages, DisplayMessage{ Role: llm.Role("tool_call"), @@ -286,7 +236,7 @@ func (m *Model) addToolCallMessage(name string, args string) { }) } -// addToolResultMessage adds a tool result message to display +// addToolResultMessage adds a tool result message to display. func (m *Model) addToolResultMessage(name string, result string) { m.messages = append(m.messages, DisplayMessage{ Role: llm.Role("tool_result"), diff --git a/cmd/llm/styles.go b/v2/cmd/llm/styles.go similarity index 100% rename from cmd/llm/styles.go rename to v2/cmd/llm/styles.go diff --git a/v2/cmd/llm/tools.go b/v2/cmd/llm/tools.go new file mode 100644 index 0000000..5f9bead --- /dev/null +++ b/v2/cmd/llm/tools.go @@ -0,0 +1,114 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" + + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" +) + +// TimeParams is the parameter struct for the GetTime function. +type TimeParams struct{} + +// GetTime returns the current time. +func GetTime(_ context.Context, _ TimeParams) (string, error) { + return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil +} + +// CalcParams is the parameter struct for the Calculate function. +type CalcParams struct { + A float64 `json:"a" description:"First number"` + B float64 `json:"b" description:"Second number"` + Op string `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"` +} + +// Calculate performs basic math operations. +func Calculate(_ context.Context, params CalcParams) (string, error) { + var result float64 + switch strings.ToLower(params.Op) { + case "add", "+": + result = params.A + params.B + case "subtract", "sub", "-": + result = params.A - params.B + case "multiply", "mul", "*": + result = params.A * params.B + case "divide", "div", "/": + if params.B == 0 { + return "", fmt.Errorf("division by zero") + } + result = params.A / params.B + case "power", "pow", "^": + result = math.Pow(params.A, params.B) + case "sqrt": + if params.A < 0 { + return "", fmt.Errorf("cannot take square root of negative number") + } + result = math.Sqrt(params.A) + case "mod", "%": + result = math.Mod(params.A, params.B) + default: + return "", fmt.Errorf("unknown operation: %s", params.Op) + } + return strconv.FormatFloat(result, 'f', -1, 64), nil +} + +// WeatherParams is the parameter struct for the GetWeather function. +type WeatherParams struct { + Location string `json:"location" description:"City name or location"` +} + +// GetWeather returns mock weather data (for demo purposes). +func GetWeather(_ context.Context, params WeatherParams) (string, error) { + weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"} + temps := []int{65, 72, 58, 80, 45} + idx := len(params.Location) % len(weathers) + + out := map[string]any{ + "location": params.Location, + "temperature": strconv.Itoa(temps[idx]) + "F", + "condition": weathers[idx], + "humidity": "45%", + "note": "This is mock data for demonstration purposes", + } + b, err := json.Marshal(out) + if err != nil { + return "", err + } + return string(b), nil +} + +// RandomNumberParams is the parameter struct for the RandomNumber function. +type RandomNumberParams struct { + Min int `json:"min" description:"Minimum value (inclusive)"` + Max int `json:"max" description:"Maximum value (inclusive)"` +} + +// RandomNumber generates a pseudo-random number (using current time nanoseconds). +func RandomNumber(_ context.Context, params RandomNumberParams) (string, error) { + if params.Min > params.Max { + return "", fmt.Errorf("min cannot be greater than max") + } + n := time.Now().UnixNano() + rangeSize := params.Max - params.Min + 1 + result := params.Min + int(n%int64(rangeSize)) + return strconv.Itoa(result), nil +} + +// createDemoToolbox creates a toolbox with demo tools for testing. +func createDemoToolbox() *llm.ToolBox { + return llm.NewToolBox( + llm.Define[TimeParams]("get_time", "Get the current date and time", GetTime), + llm.Define[CalcParams]("calculate", + "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", + Calculate), + llm.Define[WeatherParams]("get_weather", + "Get weather information for a location (demo data)", GetWeather), + llm.Define[RandomNumberParams]("random_number", + "Generate a random number between min and max", RandomNumber), + ) +} diff --git a/cmd/llm/update.go b/v2/cmd/llm/update.go similarity index 68% rename from cmd/llm/update.go rename to v2/cmd/llm/update.go index 4592683..54391b9 100644 --- a/cmd/llm/update.go +++ b/v2/cmd/llm/update.go @@ -8,14 +8,14 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" - llm "gitea.stevedudenhoeffer.com/steve/go-llm" + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" ) -// pendingRequest stores the request being processed for follow-up -var pendingRequest llm.Request -var pendingResponse llm.ResponseChoice +// pendingToolCalls stores the last response's tool calls so we can pair them +// with tool execution results for display. +var pendingToolCalls []llm.ToolCall -// Update handles messages and updates the model +// Update handles messages and updates the model. func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd var cmds []tea.Cmd @@ -53,40 +53,30 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } - if len(msg.Response.Choices) == 0 { - m.err = fmt.Errorf("no response choices returned") - return m, nil + resp := msg.Response + + // Add the assistant message to the conversation history. + m.conversation = append(m.conversation, resp.Message()) + + // Show any text the assistant produced alongside tool calls. + if resp.Text != "" { + m.addAssistantMessage(resp.Text) } - choice := msg.Response.Choices[0] + if resp.HasToolCalls() && m.toolsEnabled { + pendingToolCalls = resp.ToolCalls - // Check for tool calls - if len(choice.Calls) > 0 && m.toolsEnabled { - // Store for follow-up - pendingResponse = choice - - // Add assistant's response to conversation if there's content - if choice.Content != "" { - m.addAssistantMessage(choice.Content) - } - - // Display tool calls - for _, call := range choice.Calls { - m.addToolCallMessage(call.FunctionCall.Name, call.FunctionCall.Arguments) + for _, call := range resp.ToolCalls { + m.addToolCallMessage(call.Name, call.Arguments) } m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() - // Execute tools m.loading = true - return m, executeTools(m.toolbox, pendingRequest, choice) + return m, executeTools(m.toolbox, resp.ToolCalls) } - // Regular response - add to conversation and display - m.conversation = append(m.conversation, choice) - m.addAssistantMessage(choice.Content) - m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() @@ -97,31 +87,24 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } - // Display tool results + // Display results paired with the tool calls that produced them. for i, result := range msg.Results { - name := pendingResponse.Calls[i].FunctionCall.Name - resultStr := fmt.Sprintf("%v", result.Result) - if result.Error != nil { - resultStr = "Error: " + result.Error.Error() + name := "" + if i < len(pendingToolCalls) { + name = pendingToolCalls[i].Name } - m.addToolResultMessage(name, resultStr) + m.addToolResultMessage(name, result.Content.Text) } - // Add tool call responses to conversation - for _, result := range msg.Results { - m.conversation = append(m.conversation, result) - } - - // Add the assistant's response to conversation - m.conversation = append(m.conversation, pendingResponse) + // Append the raw tool result messages to the conversation so the + // assistant can reference them on the next turn. + m.conversation = append(m.conversation, msg.Results...) m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() - // Send follow-up request - followUp := buildFollowUpRequest(&m, pendingRequest, pendingResponse, msg.Results) - pendingRequest = followUp - return m, sendChatRequest(m.chat, followUp) + // Ask the model to continue given the tool results. + return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature) case ImageLoadedMsg: if msg.Err != nil { @@ -135,7 +118,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.err = nil default: - // Update text input + // Update text input. if m.state == StateChat { m.input, cmd = m.input.Update(msg) cmds = append(cmds, cmd) @@ -148,13 +131,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } -// handleKeyMsg handles keyboard input +// handleKeyMsg handles keyboard input. func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - // Global key handling switch msg.String() { case "ctrl+c": return m, tea.Quit - case "esc": if m.state != StateChat { m.state = StateChat @@ -164,7 +145,6 @@ func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, tea.Quit } - // State-specific key handling switch m.state { case StateChat: return m.handleChatKeys(msg) @@ -185,7 +165,7 @@ func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleChatKeys handles keys in chat state +// handleChatKeys handles keys in chat state. func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "enter": @@ -203,14 +183,13 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } - // Build and send request - req := buildRequest(&m, text) - pendingRequest = req + // Ensure a system message is at the head of the conversation. + if len(m.conversation) == 0 && m.systemPrompt != "" { + m.conversation = append(m.conversation, llm.SystemMessage(m.systemPrompt)) + } - // Add user message to display m.addUserMessage(text, m.pendingImages) - // Clear input and pending images m.input.Reset() m.pendingImages = nil m.err = nil @@ -219,7 +198,7 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.viewport.SetContent(m.renderMessages()) m.viewport.GotoBottom() - return m, sendChatRequest(m.chat, req) + return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature) case "ctrl+i": m.previousState = StateChat @@ -238,12 +217,12 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil case "ctrl+m": - if m.provider == nil { + if m.client == nil { m.err = fmt.Errorf("select a provider first") return m, nil } m.state = StateModelSelect - m.listItems = m.providers[m.providerIndex].Models + m.listItems = m.providers[m.providerIndex].Info.Models m.listIndex = m.providers[m.providerIndex].ModelIndex return m, nil @@ -268,7 +247,7 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } -// handleProviderSelectKeys handles keys in provider selection state +// handleProviderSelectKeys handles keys in provider selection state. func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "up", "k": @@ -282,15 +261,13 @@ func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { case "enter": p := m.providers[m.listIndex] if !p.HasAPIKey { - // Need to get API key m.state = StateAPIKeyInput m.apiKeyInput.Focus() m.apiKeyInput.SetValue("") return m, textinput.Blink } - err := m.selectProvider(m.listIndex) - if err != nil { + if err := m.selectProvider(m.listIndex); err != nil { m.err = err return m, nil } @@ -303,7 +280,7 @@ func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleAPIKeyInputKeys handles keys in API key input state +// handleAPIKeyInputKeys handles keys in API key input state. func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "enter": @@ -312,23 +289,22 @@ func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } - // Store the API key p := m.providers[m.listIndex] - m.apiKeys[p.Name] = key + m.apiKeys[p.Info.Name] = key m.providers[m.listIndex].HasAPIKey = true - // Update list items for i, prov := range m.providers { status := " (no key)" if prov.HasAPIKey { status = " (ready)" + if prov.Info.EnvKey == "" { + status = " (local)" + } } - m.listItems[i] = prov.Name + status + m.listItems[i] = prov.Info.DisplayName + status } - // Select the provider - err := m.selectProvider(m.listIndex) - if err != nil { + if err := m.selectProvider(m.listIndex); err != nil { m.err = err return m, nil } @@ -345,7 +321,7 @@ func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } -// handleModelSelectKeys handles keys in model selection state +// handleModelSelectKeys handles keys in model selection state. func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "up", "k": @@ -357,8 +333,7 @@ func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.listIndex++ } case "enter": - err := m.selectModel(m.listIndex) - if err != nil { + if err := m.selectModel(m.listIndex); err != nil { m.err = err return m, nil } @@ -368,7 +343,7 @@ func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleImageInputKeys handles keys in image input state +// handleImageInputKeys handles keys in image input state. func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "enter": @@ -381,12 +356,12 @@ func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.input.Placeholder = "Type your message..." - // Determine input type and load - if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { + switch { + case strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://"): return m, loadImageFromURL(input) - } else if strings.HasPrefix(input, "data:") || len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\") { + case strings.HasPrefix(input, "data:") || (len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\")): return m, loadImageFromBase64(input) - } else { + default: return m, loadImageFromPath(input) } @@ -397,7 +372,7 @@ func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } -// handleToolsPanelKeys handles keys in tools panel state +// handleToolsPanelKeys handles keys in tools panel state. func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "t": @@ -409,11 +384,10 @@ func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } -// handleSettingsKeys handles keys in settings state +// handleSettingsKeys handles keys in settings state. func (m Model) handleSettingsKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "1": - // Set temperature to nil (default) m.temperature = nil case "2": t := 0.0 diff --git a/cmd/llm/view.go b/v2/cmd/llm/view.go similarity index 88% rename from cmd/llm/view.go rename to v2/cmd/llm/view.go index bc63e0e..66bc449 100644 --- a/cmd/llm/view.go +++ b/v2/cmd/llm/view.go @@ -6,10 +6,10 @@ import ( "github.com/charmbracelet/lipgloss" - llm "gitea.stevedudenhoeffer.com/steve/go-llm" + llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2" ) -// View renders the current state +// View renders the current state. func (m Model) View() string { switch m.state { case StateProviderSelect: @@ -29,11 +29,10 @@ func (m Model) View() string { } } -// renderChat renders the main chat view +// renderChat renders the main chat view. func (m Model) renderChat() string { var b strings.Builder - // Header provider := m.providerName if provider == "" { provider = "None" @@ -49,43 +48,37 @@ func (m Model) renderChat() string { b.WriteString(header) b.WriteString("\n") - // Messages viewport if m.viewportReady { b.WriteString(m.viewport.View()) b.WriteString("\n") } - // Image indicator if len(m.pendingImages) > 0 { b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s) attached]", len(m.pendingImages)))) b.WriteString("\n") } - // Error if m.err != nil { b.WriteString(errorStyle.Render(" Error: " + m.err.Error())) b.WriteString("\n") } - // Loading if m.loading { b.WriteString(loadingStyle.Render(" Thinking...")) b.WriteString("\n") } - // Input inputBox := inputStyle.Render(m.input.View()) b.WriteString(inputBox) b.WriteString("\n") - // Help help := inputHelpStyle.Render("Enter: send | Ctrl+I: image | Ctrl+T: tools | Ctrl+P: provider | Ctrl+M: model | Ctrl+S: settings | Ctrl+N: new | Esc: quit") b.WriteString(help) return appStyle.Render(b.String()) } -// renderMessages renders all messages for the viewport +// renderMessages renders all messages for the viewport. func (m Model) renderMessages() string { var b strings.Builder @@ -133,7 +126,7 @@ func (m Model) renderMessages() string { return b.String() } -// renderProviderSelect renders the provider selection view +// renderProviderSelect renders the provider selection view. func (m Model) renderProviderSelect() string { var b strings.Builder @@ -157,16 +150,18 @@ func (m Model) renderProviderSelect() string { return appStyle.Render(b.String()) } -// renderAPIKeyInput renders the API key input view +// renderAPIKeyInput renders the API key input view. func (m Model) renderAPIKeyInput() string { var b strings.Builder provider := m.providers[m.listIndex] - b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Name))) + b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Info.DisplayName))) b.WriteString("\n\n") - b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.EnvVar)) + if provider.Info.EnvKey != "" { + b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.Info.EnvKey)) + } b.WriteString("Enter your API key below (it will be hidden):\n\n") inputBox := inputStyle.Render(m.apiKeyInput.View()) @@ -178,7 +173,7 @@ func (m Model) renderAPIKeyInput() string { return appStyle.Render(b.String()) } -// renderModelSelect renders the model selection view +// renderModelSelect renders the model selection view. func (m Model) renderModelSelect() string { var b strings.Builder @@ -205,7 +200,7 @@ func (m Model) renderModelSelect() string { return appStyle.Render(b.String()) } -// renderImageInput renders the image input view +// renderImageInput renders the image input view. func (m Model) renderImageInput() string { var b strings.Builder @@ -230,7 +225,7 @@ func (m Model) renderImageInput() string { return appStyle.Render(b.String()) } -// renderToolsPanel renders the tools panel +// renderToolsPanel renders the tools panel. func (m Model) renderToolsPanel() string { var b strings.Builder @@ -249,8 +244,10 @@ func (m Model) renderToolsPanel() string { b.WriteString("\n\n") b.WriteString("Available tools:\n") - for _, fn := range m.toolbox.Functions() { - b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(fn.Name), fn.Description)) + if m.toolbox != nil { + for _, t := range m.toolbox.AllTools() { + b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(t.Name), t.Description)) + } } b.WriteString("\n") @@ -259,14 +256,13 @@ func (m Model) renderToolsPanel() string { return appStyle.Render(b.String()) } -// renderSettings renders the settings view +// renderSettings renders the settings view. func (m Model) renderSettings() string { var b strings.Builder b.WriteString(headerStyle.Render("Settings")) b.WriteString("\n\n") - // Temperature tempStr := "default" if m.temperature != nil { tempStr = fmt.Sprintf("%.1f", *m.temperature) @@ -284,7 +280,6 @@ func (m Model) renderSettings() string { b.WriteString("\n") - // System prompt b.WriteString(settingLabelStyle.Render("System Prompt:")) b.WriteString("\n") b.WriteString(settingValueStyle.Render(" " + m.systemPrompt)) diff --git a/v2/constructors.go b/v2/constructors.go index 9fa7d11..846c697 100644 --- a/v2/constructors.go +++ b/v2/constructors.go @@ -2,8 +2,13 @@ package llm import ( anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic" + deepseekProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek" googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google" + groqProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq" + moonshotProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot" + ollamaProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama" openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai" + xaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai" ) // OpenAI creates an OpenAI client. @@ -46,3 +51,69 @@ func Google(apiKey string, opts ...ClientOption) *Client { _ = cfg // Google doesn't support custom base URL in the SDK return NewClient(googleProvider.New(apiKey)) } + +// DeepSeek creates a DeepSeek client (OpenAI-compatible). +// +// Example: +// +// model := llm.DeepSeek("sk-...").Model("deepseek-chat") +func DeepSeek(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(deepseekProvider.New(apiKey, cfg.baseURL)) +} + +// Moonshot creates a Moonshot AI (Kimi) client (OpenAI-compatible). +// +// Example: +// +// model := llm.Moonshot("sk-...").Model("kimi-k2-0711-preview") +func Moonshot(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(moonshotProvider.New(apiKey, cfg.baseURL)) +} + +// XAI creates an xAI (Grok) client (OpenAI-compatible). +// +// Example: +// +// model := llm.XAI("xai-...").Model("grok-2") +func XAI(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(xaiProvider.New(apiKey, cfg.baseURL)) +} + +// Groq creates a Groq client (OpenAI-compatible). +// +// Example: +// +// model := llm.Groq("gsk-...").Model("llama-3.3-70b-versatile") +func Groq(apiKey string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(groqProvider.New(apiKey, cfg.baseURL)) +} + +// Ollama creates a client for a local Ollama instance (OpenAI-compatible). +// No API key is required. Use WithBaseURL to point at a non-default host/port. +// +// Example: +// +// model := llm.Ollama().Model("llama3.2") +func Ollama(opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(ollamaProvider.New("", cfg.baseURL)) +} diff --git a/v2/deepseek/deepseek.go b/v2/deepseek/deepseek.go new file mode 100644 index 0000000..be67913 --- /dev/null +++ b/v2/deepseek/deepseek.go @@ -0,0 +1,36 @@ +// Package deepseek implements the go-llm v2 provider interface for DeepSeek +// (https://platform.deepseek.com). DeepSeek speaks the OpenAI Chat Completions +// protocol, so this package is a thin wrapper around openaicompat with its own +// defaults and per-model Rules. +package deepseek + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public DeepSeek API endpoint. +const DefaultBaseURL = "https://api.deepseek.com/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new DeepSeek provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // DeepSeek's chat and reasoner models are text-only. + SupportsVision: func(string) bool { return false }, + // Reasoner doesn't accept tool calls. + SupportsTools: func(m string) bool { + return !strings.Contains(m, "reasoner") + }, + // Reasoner rejects user-supplied temperature. + RestrictTemperature: func(m string) bool { + return strings.Contains(m, "reasoner") + }, + }) +} diff --git a/v2/deepseek/deepseek_test.go b/v2/deepseek/deepseek_test.go new file mode 100644 index 0000000..0d9b629 --- /dev/null +++ b/v2/deepseek/deepseek_test.go @@ -0,0 +1,49 @@ +package deepseek_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestNew_DefaultBaseURL(t *testing.T) { + if p := deepseek.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_ReasonerRejectsTools(t *testing.T) { + p := deepseek.New("key", "") + req := provider.Request{ + Model: "deepseek-reasoner", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + Tools: []provider.ToolDef{ + {Name: "x", Schema: map[string]any{"type": "object"}}, + }, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "tools" { + t.Fatalf("want FeatureUnsupportedError(tools), got %v", err) + } +} + +func TestRules_ChatRejectsImages(t *testing.T) { + p := deepseek.New("key", "") + req := provider.Request{ + Model: "deepseek-chat", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{{URL: "a"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "vision" { + t.Fatalf("want FeatureUnsupportedError(vision), got %v", err) + } +} diff --git a/v2/go.mod b/v2/go.mod index d57588e..cedc62d 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -1,10 +1,12 @@ module gitea.stevedudenhoeffer.com/steve/go-llm/v2 -go 1.24.0 - -toolchain go1.24.2 +go 1.24.2 require ( + github.com/charmbracelet/bubbles v1.0.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/joho/godotenv v1.5.1 github.com/liushuangls/go-anthropic/v2 v2.17.0 github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/openai/openai-go v1.12.0 @@ -18,6 +20,16 @@ require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.9.3 // indirect cloud.google.com/go/compute/metadata v0.5.0 // indirect + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/charmbracelet/colorprofile v0.4.1 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.9.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.5.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/jsonschema-go v0.3.0 // indirect @@ -25,15 +37,24 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/kr/fs v0.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.42.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/grpc v1.66.2 // indirect diff --git a/v2/go.sum b/v2/go.sum index bd64989..5ef6ced 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -6,8 +6,32 @@ cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842Bg cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= +github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= +github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -16,6 +40,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -49,12 +75,28 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c= github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= @@ -62,6 +104,8 @@ github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1Hbe github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -80,6 +124,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= @@ -89,6 +135,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -114,8 +162,10 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/v2/groq/groq.go b/v2/groq/groq.go new file mode 100644 index 0000000..cd1d9a7 --- /dev/null +++ b/v2/groq/groq.go @@ -0,0 +1,33 @@ +// Package groq implements the go-llm v2 provider interface for Groq +// (https://console.groq.com). Groq hosts open-source models behind an OpenAI +// Chat Completions-compatible endpoint, so this package is a thin wrapper over +// openaicompat with its own defaults and per-model Rules. +package groq + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public Groq OpenAI-compatible endpoint. +const DefaultBaseURL = "https://api.groq.com/openai/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new Groq provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // Only Groq-hosted vision variants (e.g. *-vision-preview) accept images. + SupportsVision: func(m string) bool { + return strings.Contains(m, "vision") + }, + // Chat completions endpoint does not accept audio input; audio is via + // dedicated transcription endpoints, which go-llm doesn't cover here. + SupportsAudio: func(string) bool { return false }, + }) +} diff --git a/v2/groq/groq_test.go b/v2/groq/groq_test.go new file mode 100644 index 0000000..d50c118 --- /dev/null +++ b/v2/groq/groq_test.go @@ -0,0 +1,33 @@ +package groq_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestNew_Basic(t *testing.T) { + if p := groq.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_AudioRejected(t *testing.T) { + p := groq.New("key", "") + req := provider.Request{ + Model: "llama-3.3-70b-versatile", + Messages: []provider.Message{{ + Role: "user", + Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "audio" { + t.Fatalf("want FeatureUnsupportedError(audio), got %v", err) + } +} diff --git a/v2/moonshot/moonshot.go b/v2/moonshot/moonshot.go new file mode 100644 index 0000000..8d7fbe1 --- /dev/null +++ b/v2/moonshot/moonshot.go @@ -0,0 +1,30 @@ +// Package moonshot implements the go-llm v2 provider interface for Moonshot +// AI (Kimi, https://platform.moonshot.ai). Moonshot speaks OpenAI Chat +// Completions, so this package is a thin wrapper over openaicompat with its +// own defaults and per-model Rules. +package moonshot + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public Moonshot API endpoint (international). +const DefaultBaseURL = "https://api.moonshot.ai/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new Moonshot provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // Only Moonshot models whose name contains "vision" accept images. + SupportsVision: func(m string) bool { + return strings.Contains(m, "vision") + }, + }) +} diff --git a/v2/moonshot/moonshot_test.go b/v2/moonshot/moonshot_test.go new file mode 100644 index 0000000..d3f3c6f --- /dev/null +++ b/v2/moonshot/moonshot_test.go @@ -0,0 +1,33 @@ +package moonshot_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestNew_Basic(t *testing.T) { + if p := moonshot.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_NonVisionModelRejectsImages(t *testing.T) { + p := moonshot.New("key", "") + req := provider.Request{ + Model: "moonshot-v1-8k", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{{URL: "a"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "vision" { + t.Fatalf("want FeatureUnsupportedError(vision), got %v", err) + } +} diff --git a/v2/ollama/ollama.go b/v2/ollama/ollama.go new file mode 100644 index 0000000..75623ca --- /dev/null +++ b/v2/ollama/ollama.go @@ -0,0 +1,25 @@ +// Package ollama implements the go-llm v2 provider interface for Ollama +// (https://ollama.com), a local model runner that exposes an OpenAI Chat +// Completions-compatible endpoint. No API key is required; capability depends +// on whichever model the user has pulled locally, so Rules are intentionally +// empty — we trust the local user. +package ollama + +import ( + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL points at a local Ollama instance with default port. +const DefaultBaseURL = "http://localhost:11434/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new Ollama provider. An empty baseURL uses DefaultBaseURL. +// Ollama ignores the API key; callers may pass "". +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{}) +} diff --git a/v2/ollama/ollama_test.go b/v2/ollama/ollama_test.go new file mode 100644 index 0000000..5c1e216 --- /dev/null +++ b/v2/ollama/ollama_test.go @@ -0,0 +1,13 @@ +package ollama_test + +import ( + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama" +) + +func TestNew_NoKeyNeeded(t *testing.T) { + if p := ollama.New("", ""); p == nil { + t.Fatal("New returned nil") + } +} diff --git a/v2/openai/openai.go b/v2/openai/openai.go index 4193542..5b40964 100644 --- a/v2/openai/openai.go +++ b/v2/openai/openai.go @@ -1,433 +1,35 @@ // Package openai implements the go-llm v2 provider interface for OpenAI. +// +// The actual wire-protocol logic lives in the shared openaicompat package; +// this file encodes OpenAI-specific Rules (temperature is rejected on o-series +// and gpt-5* models) and supplies the default base URL. package openai import ( - "context" - "encoding/base64" - "fmt" - "io" - "net/http" - "path" "strings" - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/shared" - - "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" ) -// Provider implements the provider.Provider interface for OpenAI. -type Provider struct { - apiKey string - baseURL string -} +// DefaultBaseURL is the public OpenAI Chat Completions endpoint. +const DefaultBaseURL = "https://api.openai.com/v1" -// New creates a new OpenAI provider. +// Provider is the OpenAI chat-completion provider. It's a type alias over +// openaicompat.Provider so existing callers using openai.Provider keep compiling. +type Provider = openaicompat.Provider + +// New creates a new OpenAI provider. An empty baseURL uses DefaultBaseURL. func New(apiKey string, baseURL string) *Provider { - return &Provider{apiKey: apiKey, baseURL: baseURL} + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + RestrictTemperature: restrictTemperature, + }) } -// Complete performs a non-streaming completion. -func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(p.apiKey)) - if p.baseURL != "" { - opts = append(opts, option.WithBaseURL(p.baseURL)) - } - - cl := openai.NewClient(opts...) - oaiReq := p.buildRequest(req) - - resp, err := cl.Chat.Completions.New(ctx, oaiReq) - if err != nil { - return provider.Response{}, fmt.Errorf("openai completion error: %w", err) - } - - return p.convertResponse(resp), nil -} - -// Stream performs a streaming completion. -func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(p.apiKey)) - if p.baseURL != "" { - opts = append(opts, option.WithBaseURL(p.baseURL)) - } - - cl := openai.NewClient(opts...) - oaiReq := p.buildRequest(req) - oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - } - - stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) - - var fullText strings.Builder - var toolCalls []provider.ToolCall - toolCallArgs := map[int]*strings.Builder{} - var usage *provider.Usage - - for stream.Next() { - chunk := stream.Current() - - // Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true) - if chunk.Usage.TotalTokens > 0 { - usage = &provider.Usage{ - InputTokens: int(chunk.Usage.PromptTokens), - OutputTokens: int(chunk.Usage.CompletionTokens), - TotalTokens: int(chunk.Usage.TotalTokens), - Details: extractUsageDetails(chunk.Usage), - } - } - - for _, choice := range chunk.Choices { - // Text delta - if choice.Delta.Content != "" { - fullText.WriteString(choice.Delta.Content) - events <- provider.StreamEvent{ - Type: provider.StreamEventText, - Text: choice.Delta.Content, - } - } - - // Tool call deltas - for _, tc := range choice.Delta.ToolCalls { - idx := int(tc.Index) - - if tc.ID != "" { - // New tool call starting - for len(toolCalls) <= idx { - toolCalls = append(toolCalls, provider.ToolCall{}) - } - toolCalls[idx].ID = tc.ID - toolCalls[idx].Name = tc.Function.Name - toolCallArgs[idx] = &strings.Builder{} - - events <- provider.StreamEvent{ - Type: provider.StreamEventToolStart, - ToolCall: &provider.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - }, - ToolIndex: idx, - } - } - - if tc.Function.Arguments != "" { - if b, ok := toolCallArgs[idx]; ok { - b.WriteString(tc.Function.Arguments) - } - events <- provider.StreamEvent{ - Type: provider.StreamEventToolDelta, - ToolIndex: idx, - ToolCall: &provider.ToolCall{ - Arguments: tc.Function.Arguments, - }, - } - } - } - } - } - - if err := stream.Err(); err != nil { - return fmt.Errorf("openai stream error: %w", err) - } - - // Finalize tool calls - for idx := range toolCalls { - if b, ok := toolCallArgs[idx]; ok { - toolCalls[idx].Arguments = b.String() - } - events <- provider.StreamEvent{ - Type: provider.StreamEventToolEnd, - ToolIndex: idx, - ToolCall: &toolCalls[idx], - } - } - - // Send done event - events <- provider.StreamEvent{ - Type: provider.StreamEventDone, - Response: &provider.Response{ - Text: fullText.String(), - ToolCalls: toolCalls, - Usage: usage, - }, - } - - return nil -} - -func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { - oaiReq := openai.ChatCompletionNewParams{ - Model: req.Model, - } - - for _, msg := range req.Messages { - oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) - } - - for _, tool := range req.Tools { - oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ - Type: "function", - Function: shared.FunctionDefinitionParam{ - Name: tool.Name, - Description: openai.String(tool.Description), - Parameters: openai.FunctionParameters(tool.Schema), - }, - }) - } - - if req.Temperature != nil { - // o* and gpt-5* models don't support custom temperatures - if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") { - oaiReq.Temperature = openai.Float(*req.Temperature) - } - } - - if req.MaxTokens != nil { - oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) - } - - if req.TopP != nil { - oaiReq.TopP = openai.Float(*req.TopP) - } - - if len(req.Stop) > 0 { - oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} - } - - return oaiReq -} - -func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { - var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam - var textContent param.Opt[string] - - for _, img := range msg.Images { - var url string - if img.Base64 != "" { - url = "data:" + img.ContentType + ";base64," + img.Base64 - } else if img.URL != "" { - url = img.URL - } - if url != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfImageURL: &openai.ChatCompletionContentPartImageParam{ - ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ - URL: url, - }, - }, - }, - ) - } - } - - for _, aud := range msg.Audio { - var b64Data string - var format string - - if aud.Base64 != "" { - b64Data = aud.Base64 - format = audioFormat(aud.ContentType) - } else if aud.URL != "" { - resp, err := http.Get(aud.URL) - if err != nil { - continue - } - data, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - continue - } - b64Data = base64.StdEncoding.EncodeToString(data) - ct := resp.Header.Get("Content-Type") - if ct == "" { - ct = aud.ContentType - } - if ct == "" { - ct = audioFormatFromURL(aud.URL) - } - format = audioFormat(ct) - } - - if b64Data != "" && format != "" { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ - InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ - Data: b64Data, - Format: format, - }, - }, - }, - ) - } - } - - if msg.Content != "" { - if len(arrayOfContentParts) > 0 { - arrayOfContentParts = append(arrayOfContentParts, - openai.ChatCompletionContentPartUnionParam{ - OfText: &openai.ChatCompletionContentPartTextParam{ - Text: msg.Content, - }, - }, - ) - } else { - textContent = openai.String(msg.Content) - } - } - - // Determine if this model uses developer messages instead of system - useDeveloper := false - parts := strings.Split(model, "-") - if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { - useDeveloper = true - } - - switch msg.Role { - case "system": - if useDeveloper { - return openai.ChatCompletionMessageParamUnion{ - OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - } - return openai.ChatCompletionMessageParamUnion{ - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: textContent, - }, - }, - } - - case "user": - return openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfString: textContent, - OfArrayOfContentParts: arrayOfContentParts, - }, - }, - } - - case "assistant": - as := &openai.ChatCompletionAssistantMessageParam{} - if msg.Content != "" { - as.Content.OfString = openai.String(msg.Content) - } - for _, tc := range msg.ToolCalls { - as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ - ID: tc.ID, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: tc.Name, - Arguments: tc.Arguments, - }, - }) - } - return openai.ChatCompletionMessageParamUnion{OfAssistant: as} - - case "tool": - return openai.ChatCompletionMessageParamUnion{ - OfTool: &openai.ChatCompletionToolMessageParam{ - ToolCallID: msg.ToolCallID, - Content: openai.ChatCompletionToolMessageParamContentUnion{ - OfString: openai.String(msg.Content), - }, - }, - } - } - - // Fallback to user message - return openai.ChatCompletionMessageParamUnion{ - OfUser: &openai.ChatCompletionUserMessageParam{ - Content: openai.ChatCompletionUserMessageParamContentUnion{ - OfString: textContent, - }, - }, - } -} - -func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response { - var res provider.Response - - if resp == nil || len(resp.Choices) == 0 { - return res - } - - choice := resp.Choices[0] - res.Text = choice.Message.Content - - for _, tc := range choice.Message.ToolCalls { - res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: strings.TrimSpace(tc.Function.Arguments), - }) - } - - if resp.Usage.TotalTokens > 0 { - res.Usage = &provider.Usage{ - InputTokens: int(resp.Usage.PromptTokens), - OutputTokens: int(resp.Usage.CompletionTokens), - TotalTokens: int(resp.Usage.TotalTokens), - } - res.Usage.Details = extractUsageDetails(resp.Usage) - } - - return res -} - -// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3"). -func audioFormat(contentType string) string { - ct := strings.ToLower(contentType) - switch { - case strings.Contains(ct, "wav"): - return "wav" - case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"): - return "mp3" - default: - return "wav" - } -} - -// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage. -func extractUsageDetails(usage openai.CompletionUsage) map[string]int { - details := map[string]int{} - if usage.CompletionTokensDetails.ReasoningTokens > 0 { - details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens) - } - if usage.CompletionTokensDetails.AudioTokens > 0 { - details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens) - } - if usage.PromptTokensDetails.CachedTokens > 0 { - details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens) - } - if usage.PromptTokensDetails.AudioTokens > 0 { - details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens) - } - if len(details) == 0 { - return nil - } - return details -} - -// audioFormatFromURL guesses the audio format from a URL's file extension. -func audioFormatFromURL(u string) string { - ext := strings.ToLower(path.Ext(u)) - switch ext { - case ".mp3": - return "audio/mp3" - case ".wav": - return "audio/wav" - default: - return "audio/wav" - } +// restrictTemperature reports whether OpenAI rejects a user-supplied +// temperature for this model. o-series reasoning models and gpt-5* both do. +func restrictTemperature(model string) bool { + return strings.HasPrefix(model, "o") || strings.HasPrefix(model, "gpt-5") } diff --git a/v2/openaicompat/openaicompat.go b/v2/openaicompat/openaicompat.go new file mode 100644 index 0000000..690b358 --- /dev/null +++ b/v2/openaicompat/openaicompat.go @@ -0,0 +1,537 @@ +// Package openaicompat implements a shared chat-completion Provider for any +// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek, +// Moonshot, xAI, Groq, Ollama, and friends). +// +// Most providers differ from vanilla OpenAI only in endpoint URL and a handful +// of per-model quirks (e.g., "this model is text-only", "this model doesn't +// accept tools", "drop temperature on reasoning models"). Those quirks are +// captured declaratively via Rules, so a concrete provider package is usually +// a one-function wrapper that calls New with its own base URL and Rules. +package openaicompat + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "path" + "strings" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/shared" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// Rules encodes provider-specific constraints on top of the OpenAI wire +// protocol. The zero value means "no restrictions" and behaves like vanilla +// OpenAI. Individual fields are documented inline. +type Rules struct { + // MaxImagesPerMessage rejects requests whose any single message carries + // more images than this cap. 0 means "no cap". + MaxImagesPerMessage int + + // MaxAudioPerMessage rejects requests whose any single message carries + // more audio attachments than this cap. 0 means "no cap". + MaxAudioPerMessage int + + // SupportsVision, when non-nil, is consulted for every request that + // includes any image attachments. If it returns false for the request's + // model, the call fails with a FeatureUnsupportedError before hitting + // the network. + SupportsVision func(model string) bool + + // SupportsTools, when non-nil, is consulted for every request that + // includes any tool definitions. If it returns false for the model, + // the call fails with a FeatureUnsupportedError before hitting the + // network. + SupportsTools func(model string) bool + + // SupportsAudio, when non-nil, is consulted for every request that + // includes any audio attachments. If it returns false for the model, + // the call fails with a FeatureUnsupportedError. + SupportsAudio func(model string) bool + + // RestrictTemperature, when non-nil and returning true for the request's + // model, causes the Temperature field to be silently dropped from the + // outgoing request. Used by OpenAI o-series and gpt-5* which reject a + // user-provided temperature. + RestrictTemperature func(model string) bool + + // CustomizeRequest is a last-mile hook invoked after buildRequest but + // before the call is sent. It receives the fully built OpenAI SDK + // parameters and may mutate them freely (add headers, flip flags, tweak + // response_format, etc.). + CustomizeRequest func(params *openai.ChatCompletionNewParams) +} + +// FeatureUnsupportedError is returned when a Rules predicate rejects a request +// because the target model does not support a feature the caller included. +type FeatureUnsupportedError struct { + Feature string + Model string +} + +func (e *FeatureUnsupportedError) Error() string { + return fmt.Sprintf("openaicompat: model %q does not support %s", e.Model, e.Feature) +} + +// Provider implements provider.Provider for any OpenAI-compatible endpoint. +type Provider struct { + apiKey string + baseURL string + rules Rules +} + +// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its +// default; in practice concrete provider packages always pass a default. +func New(apiKey, baseURL string, rules Rules) *Provider { + return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules} +} + +// Complete performs a non-streaming completion. +func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + if err := p.checkRules(req); err != nil { + return provider.Response{}, err + } + + cl := openai.NewClient(p.requestOptions()...) + oaiReq := p.buildRequest(req) + if p.rules.CustomizeRequest != nil { + p.rules.CustomizeRequest(&oaiReq) + } + + resp, err := cl.Chat.Completions.New(ctx, oaiReq) + if err != nil { + return provider.Response{}, fmt.Errorf("openai completion error: %w", err) + } + + return p.convertResponse(resp), nil +} + +// Stream performs a streaming completion. +func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + if err := p.checkRules(req); err != nil { + return err + } + + cl := openai.NewClient(p.requestOptions()...) + oaiReq := p.buildRequest(req) + oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + } + if p.rules.CustomizeRequest != nil { + p.rules.CustomizeRequest(&oaiReq) + } + + stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) + + var fullText strings.Builder + var toolCalls []provider.ToolCall + toolCallArgs := map[int]*strings.Builder{} + var usage *provider.Usage + + for stream.Next() { + chunk := stream.Current() + + // Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true) + if chunk.Usage.TotalTokens > 0 { + usage = &provider.Usage{ + InputTokens: int(chunk.Usage.PromptTokens), + OutputTokens: int(chunk.Usage.CompletionTokens), + TotalTokens: int(chunk.Usage.TotalTokens), + Details: extractUsageDetails(chunk.Usage), + } + } + + for _, choice := range chunk.Choices { + // Text delta + if choice.Delta.Content != "" { + fullText.WriteString(choice.Delta.Content) + events <- provider.StreamEvent{ + Type: provider.StreamEventText, + Text: choice.Delta.Content, + } + } + + // Tool call deltas + for _, tc := range choice.Delta.ToolCalls { + idx := int(tc.Index) + + if tc.ID != "" { + // New tool call starting + for len(toolCalls) <= idx { + toolCalls = append(toolCalls, provider.ToolCall{}) + } + toolCalls[idx].ID = tc.ID + toolCalls[idx].Name = tc.Function.Name + toolCallArgs[idx] = &strings.Builder{} + + events <- provider.StreamEvent{ + Type: provider.StreamEventToolStart, + ToolCall: &provider.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + }, + ToolIndex: idx, + } + } + + if tc.Function.Arguments != "" { + if b, ok := toolCallArgs[idx]; ok { + b.WriteString(tc.Function.Arguments) + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolDelta, + ToolIndex: idx, + ToolCall: &provider.ToolCall{ + Arguments: tc.Function.Arguments, + }, + } + } + } + } + } + + if err := stream.Err(); err != nil { + return fmt.Errorf("openai stream error: %w", err) + } + + // Finalize tool calls + for idx := range toolCalls { + if b, ok := toolCallArgs[idx]; ok { + toolCalls[idx].Arguments = b.String() + } + events <- provider.StreamEvent{ + Type: provider.StreamEventToolEnd, + ToolIndex: idx, + ToolCall: &toolCalls[idx], + } + } + + events <- provider.StreamEvent{ + Type: provider.StreamEventDone, + Response: &provider.Response{ + Text: fullText.String(), + ToolCalls: toolCalls, + Usage: usage, + }, + } + + return nil +} + +func (p *Provider) requestOptions() []option.RequestOption { + opts := []option.RequestOption{option.WithAPIKey(p.apiKey)} + if p.baseURL != "" { + opts = append(opts, option.WithBaseURL(p.baseURL)) + } + return opts +} + +// checkRules applies all Rules predicates against a request and returns an +// error if any constraint is violated. Runs before any network call. +func (p *Provider) checkRules(req provider.Request) error { + var hasImages, hasAudio bool + for _, msg := range req.Messages { + if len(msg.Images) > 0 { + hasImages = true + } + if len(msg.Audio) > 0 { + hasAudio = true + } + if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage { + return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q", + len(msg.Images), p.rules.MaxImagesPerMessage, req.Model) + } + if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage { + return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q", + len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model) + } + } + + if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) { + return &FeatureUnsupportedError{Feature: "vision", Model: req.Model} + } + if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) { + return &FeatureUnsupportedError{Feature: "audio", Model: req.Model} + } + if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) { + return &FeatureUnsupportedError{Feature: "tools", Model: req.Model} + } + return nil +} + +func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { + oaiReq := openai.ChatCompletionNewParams{ + Model: req.Model, + } + + for _, msg := range req.Messages { + oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) + } + + for _, tool := range req.Tools { + oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ + Type: "function", + Function: shared.FunctionDefinitionParam{ + Name: tool.Name, + Description: openai.String(tool.Description), + Parameters: openai.FunctionParameters(tool.Schema), + }, + }) + } + + if req.Temperature != nil { + if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) { + oaiReq.Temperature = openai.Float(*req.Temperature) + } + } + + if req.MaxTokens != nil { + oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) + } + + if req.TopP != nil { + oaiReq.TopP = openai.Float(*req.TopP) + } + + if len(req.Stop) > 0 { + oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} + } + + return oaiReq +} + +func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { + var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam + var textContent param.Opt[string] + + for _, img := range msg.Images { + var url string + if img.Base64 != "" { + url = "data:" + img.ContentType + ";base64," + img.Base64 + } else if img.URL != "" { + url = img.URL + } + if url != "" { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: url, + }, + }, + }, + ) + } + } + + for _, aud := range msg.Audio { + var b64Data string + var format string + + if aud.Base64 != "" { + b64Data = aud.Base64 + format = audioFormat(aud.ContentType) + } else if aud.URL != "" { + resp, err := http.Get(aud.URL) + if err != nil { + continue + } + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + continue + } + b64Data = base64.StdEncoding.EncodeToString(data) + ct := resp.Header.Get("Content-Type") + if ct == "" { + ct = aud.ContentType + } + if ct == "" { + ct = audioFormatFromURL(aud.URL) + } + format = audioFormat(ct) + } + + if b64Data != "" && format != "" { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ + InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ + Data: b64Data, + Format: format, + }, + }, + }, + ) + } + } + + if msg.Content != "" { + if len(arrayOfContentParts) > 0 { + arrayOfContentParts = append(arrayOfContentParts, + openai.ChatCompletionContentPartUnionParam{ + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: msg.Content, + }, + }, + ) + } else { + textContent = openai.String(msg.Content) + } + } + + // Determine if this model uses developer messages instead of system + useDeveloper := false + parts := strings.Split(model, "-") + if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { + useDeveloper = true + } + + switch msg.Role { + case "system": + if useDeveloper { + return openai.ChatCompletionMessageParamUnion{ + OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + } + return openai.ChatCompletionMessageParamUnion{ + OfSystem: &openai.ChatCompletionSystemMessageParam{ + Content: openai.ChatCompletionSystemMessageParamContentUnion{ + OfString: textContent, + }, + }, + } + + case "user": + return openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + OfArrayOfContentParts: arrayOfContentParts, + }, + }, + } + + case "assistant": + as := &openai.ChatCompletionAssistantMessageParam{} + if msg.Content != "" { + as.Content.OfString = openai.String(msg.Content) + } + for _, tc := range msg.ToolCalls { + as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ + ID: tc.ID, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: tc.Name, + Arguments: tc.Arguments, + }, + }) + } + return openai.ChatCompletionMessageParamUnion{OfAssistant: as} + + case "tool": + return openai.ChatCompletionMessageParamUnion{ + OfTool: &openai.ChatCompletionToolMessageParam{ + ToolCallID: msg.ToolCallID, + Content: openai.ChatCompletionToolMessageParamContentUnion{ + OfString: openai.String(msg.Content), + }, + }, + } + } + + // Fallback to user message + return openai.ChatCompletionMessageParamUnion{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: textContent, + }, + }, + } +} + +func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response { + var res provider.Response + + if resp == nil || len(resp.Choices) == 0 { + return res + } + + choice := resp.Choices[0] + res.Text = choice.Message.Content + + for _, tc := range choice.Message.ToolCalls { + res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: strings.TrimSpace(tc.Function.Arguments), + }) + } + + if resp.Usage.TotalTokens > 0 { + res.Usage = &provider.Usage{ + InputTokens: int(resp.Usage.PromptTokens), + OutputTokens: int(resp.Usage.CompletionTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + res.Usage.Details = extractUsageDetails(resp.Usage) + } + + return res +} + +// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3"). +func audioFormat(contentType string) string { + ct := strings.ToLower(contentType) + switch { + case strings.Contains(ct, "wav"): + return "wav" + case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"): + return "mp3" + default: + return "wav" + } +} + +// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage. +func extractUsageDetails(usage openai.CompletionUsage) map[string]int { + details := map[string]int{} + if usage.CompletionTokensDetails.ReasoningTokens > 0 { + details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens) + } + if usage.CompletionTokensDetails.AudioTokens > 0 { + details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens) + } + if usage.PromptTokensDetails.CachedTokens > 0 { + details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens) + } + if usage.PromptTokensDetails.AudioTokens > 0 { + details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens) + } + if len(details) == 0 { + return nil + } + return details +} + +// audioFormatFromURL guesses the audio format from a URL's file extension. +func audioFormatFromURL(u string) string { + ext := strings.ToLower(path.Ext(u)) + switch ext { + case ".mp3": + return "audio/mp3" + case ".wav": + return "audio/wav" + default: + return "audio/wav" + } +} diff --git a/v2/openaicompat/openaicompat_test.go b/v2/openaicompat/openaicompat_test.go new file mode 100644 index 0000000..efd412a --- /dev/null +++ b/v2/openaicompat/openaicompat_test.go @@ -0,0 +1,313 @@ +package openaicompat_test + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/openai/openai-go" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// newTestServer returns a httptest server that captures the raw request body +// on POST /chat/completions and returns a canned OpenAI response so Complete() +// succeeds. Use `captured` to assert on what the provider would send. +func newTestServer(t *testing.T) (*httptest.Server, *[]byte) { + t.Helper() + var body []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + http.NotFound(w, r) + return + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + body = b + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "id": "cmpl-1", + "object": "chat.completion", + "choices": [{ + "index": 0, + "message": {"role":"assistant","content":"ok"}, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`) + })) + return srv, &body +} + +func textReq(model, content string) provider.Request { + return provider.Request{ + Model: model, + Messages: []provider.Message{{Role: "user", Content: content}}, + } +} + +func TestComplete_ZeroRulesPassesThrough(t *testing.T) { + srv, body := newTestServer(t) + defer srv.Close() + + temp := 0.7 + req := textReq("gpt-4o", "hi") + req.Temperature = &temp + + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{}) + resp, err := p.Complete(context.Background(), req) + if err != nil { + t.Fatalf("Complete: %v", err) + } + if resp.Text != "ok" { + t.Errorf("Text = %q, want %q", resp.Text, "ok") + } + + // Temperature should be present since RestrictTemperature is nil. + var parsed map[string]any + if err := json.Unmarshal(*body, &parsed); err != nil { + t.Fatalf("unmarshal request body: %v", err) + } + if _, ok := parsed["temperature"]; !ok { + t.Errorf("expected temperature in request body, got: %s", *body) + } +} + +func TestComplete_RestrictTemperatureDropsField(t *testing.T) { + srv, body := newTestServer(t) + defer srv.Close() + + temp := 0.7 + req := textReq("o1", "hi") + req.Temperature = &temp + + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{ + RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") }, + }) + if _, err := p.Complete(context.Background(), req); err != nil { + t.Fatalf("Complete: %v", err) + } + + var parsed map[string]any + if err := json.Unmarshal(*body, &parsed); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if _, ok := parsed["temperature"]; ok { + t.Errorf("temperature should be dropped for o1, got: %s", *body) + } +} + +func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) { + srv, _ := newTestServer(t) + defer srv.Close() + + req := provider.Request{ + Model: "deepseek-chat", + Messages: []provider.Message{{ + Role: "user", + Content: "describe", + Images: []provider.Image{{URL: "https://example.com/a.png"}}, + }}, + } + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{ + SupportsVision: func(string) bool { return false }, + }) + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) { + t.Fatalf("want FeatureUnsupportedError, got %v", err) + } + if fue.Feature != "vision" || fue.Model != "deepseek-chat" { + t.Errorf("unexpected err: %+v", fue) + } +} + +func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) { + srv, _ := newTestServer(t) + defer srv.Close() + + req := provider.Request{ + Model: "deepseek-reasoner", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + Tools: []provider.ToolDef{ + {Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}}, + }, + } + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{ + SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") }, + }) + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) { + t.Fatalf("want FeatureUnsupportedError, got %v", err) + } + if fue.Feature != "tools" { + t.Errorf("feature = %q, want tools", fue.Feature) + } +} + +func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) { + srv, _ := newTestServer(t) + defer srv.Close() + + req := provider.Request{ + Model: "groq-llama", + Messages: []provider.Message{{ + Role: "user", + Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}}, + }}, + } + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{ + SupportsAudio: func(string) bool { return false }, + }) + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) { + t.Fatalf("want FeatureUnsupportedError, got %v", err) + } + if fue.Feature != "audio" { + t.Errorf("feature = %q, want audio", fue.Feature) + } +} + +func TestComplete_MaxImagesPerMessage(t *testing.T) { + srv, _ := newTestServer(t) + defer srv.Close() + + req := provider.Request{ + Model: "anything", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{ + {URL: "a"}, {URL: "b"}, {URL: "c"}, + }, + }}, + } + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2}) + _, err := p.Complete(context.Background(), req) + if err == nil || !strings.Contains(err.Error(), "max allowed is 2") { + t.Fatalf("want max-images error, got %v", err) + } + + // Exactly at limit succeeds. + req.Messages[0].Images = req.Messages[0].Images[:2] + if _, err := p.Complete(context.Background(), req); err != nil { + t.Errorf("at-limit request should succeed, got %v", err) + } +} + +func TestComplete_CustomizeRequestInvoked(t *testing.T) { + srv, body := newTestServer(t) + defer srv.Close() + + called := false + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{ + CustomizeRequest: func(params *openai.ChatCompletionNewParams) { + called = true + // Confirm we receive a non-empty built request. + if params.Model != "gpt-4o" { + t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model) + } + // Mutation here should end up on the wire. + params.User = openai.String("test-user") + }, + }) + if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil { + t.Fatalf("Complete: %v", err) + } + if !called { + t.Fatal("CustomizeRequest hook was not invoked") + } + if !strings.Contains(string(*body), `"user":"test-user"`) { + t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body) + } +} + +func TestStream_EmitsDoneAndText(t *testing.T) { + // SSE stream with one content chunk then [DONE]. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + for _, line := range []string{ + `data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`, + `data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`, + `data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + } { + _, _ = io.WriteString(w, line+"\n\n") + if flusher != nil { + flusher.Flush() + } + } + })) + defer srv.Close() + + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{}) + events := make(chan provider.StreamEvent, 16) + go func() { + _ = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events) + close(events) + }() + + var text strings.Builder + var sawDone bool + var doneUsage *provider.Usage + for ev := range events { + switch ev.Type { + case provider.StreamEventText: + text.WriteString(ev.Text) + case provider.StreamEventDone: + sawDone = true + if ev.Response != nil { + doneUsage = ev.Response.Usage + } + } + } + if text.String() != "hello" { + t.Errorf("got text %q, want %q", text.String(), "hello") + } + if !sawDone { + t.Fatal("no Done event emitted") + } + if doneUsage == nil || doneUsage.TotalTokens != 3 { + t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage) + } +} + +func TestStream_RulesCheckedBeforeNetwork(t *testing.T) { + // Server should never be hit when rules reject up front. + hit := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hit = true + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{ + SupportsVision: func(string) bool { return false }, + }) + req := provider.Request{ + Model: "no-vision-model", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{{URL: "a"}}, + }}, + } + events := make(chan provider.StreamEvent, 4) + err := p.Stream(context.Background(), req, events) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) { + t.Fatalf("want FeatureUnsupportedError, got %v", err) + } + if hit { + t.Error("server was contacted despite Rules violation") + } +} diff --git a/v2/registry.go b/v2/registry.go new file mode 100644 index 0000000..58ea1f5 --- /dev/null +++ b/v2/registry.go @@ -0,0 +1,158 @@ +package llm + +import ( + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai" +) + +// ProviderInfo describes a registered provider for discovery purposes (CLI +// pickers, wiring layers, admin tools). It is the single source of truth for +// "what providers exist and how do I instantiate one." +type ProviderInfo struct { + // Name is the short lowercase identifier used in provider/model strings + // (e.g., "openai", "deepseek", "moonshot"). + Name string + + // DisplayName is a human-readable label for UIs. + DisplayName string + + // EnvKey is the conventional environment variable that holds the API key + // for this provider. Empty string means "no key needed" (e.g., Ollama). + EnvKey string + + // DefaultURL is the default base URL used when no override is supplied. + DefaultURL string + + // Models is a list of well-known model names, populated for CLI pickers + // and similar. It is not exhaustive and not validated against the API. + Models []string + + // New returns a ready-to-use Client for this provider, given an API key + // (ignored for key-less providers like Ollama) and optional ClientOptions. + New func(apiKey string, opts ...ClientOption) *Client +} + +// providerRegistry is the in-process list of known providers. Order is +// intentional: the three original providers first, then OpenAI-compatible +// additions in the order they were added. +var providerRegistry = []ProviderInfo{ + { + Name: "openai", + DisplayName: "OpenAI", + EnvKey: "OPENAI_API_KEY", + DefaultURL: openai.DefaultBaseURL, + Models: []string{ + "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", + "gpt-4o", "gpt-4o-mini", + "gpt-4-turbo", "gpt-3.5-turbo", + "o1", "o1-mini", "o1-preview", "o3-mini", + }, + New: OpenAI, + }, + { + Name: "anthropic", + DisplayName: "Anthropic", + EnvKey: "ANTHROPIC_API_KEY", + DefaultURL: "https://api.anthropic.com", + Models: []string{ + "claude-opus-4-7", + "claude-sonnet-4-6", + "claude-haiku-4-5-20251001", + "claude-opus-4-20250514", + "claude-sonnet-4-20250514", + "claude-3-7-sonnet-20250219", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + }, + New: Anthropic, + }, + { + Name: "google", + DisplayName: "Google", + EnvKey: "GOOGLE_API_KEY", + DefaultURL: "https://generativelanguage.googleapis.com", + Models: []string{ + "gemini-2.0-flash", "gemini-2.0-flash-lite", + "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b", + }, + New: Google, + }, + { + Name: "deepseek", + DisplayName: "DeepSeek", + EnvKey: "DEEPSEEK_API_KEY", + DefaultURL: deepseek.DefaultBaseURL, + Models: []string{"deepseek-chat", "deepseek-reasoner"}, + New: DeepSeek, + }, + { + Name: "moonshot", + DisplayName: "Moonshot (Kimi)", + EnvKey: "MOONSHOT_API_KEY", + DefaultURL: moonshot.DefaultBaseURL, + Models: []string{ + "kimi-k2-0711-preview", + "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k", + "moonshot-v1-8k-vision-preview", + }, + New: Moonshot, + }, + { + Name: "xai", + DisplayName: "xAI (Grok)", + EnvKey: "XAI_API_KEY", + DefaultURL: xai.DefaultBaseURL, + Models: []string{ + "grok-2", "grok-2-mini", "grok-2-vision", "grok-beta", + }, + New: XAI, + }, + { + Name: "groq", + DisplayName: "Groq", + EnvKey: "GROQ_API_KEY", + DefaultURL: groq.DefaultBaseURL, + Models: []string{ + "llama-3.3-70b-versatile", + "llama-3.1-8b-instant", + "mixtral-8x7b-32768", + "gemma2-9b-it", + "llama-3.2-90b-vision-preview", + }, + New: Groq, + }, + { + Name: "ollama", + DisplayName: "Ollama (local)", + EnvKey: "", // no key needed + DefaultURL: ollama.DefaultBaseURL, + Models: []string{ + "llama3.2", "llama3.1", "qwen2.5", "mistral", "gemma2", "phi4", + }, + New: func(_ string, opts ...ClientOption) *Client { return Ollama(opts...) }, + }, +} + +// Providers returns a copy of the registered provider list so callers cannot +// mutate library state. +func Providers() []ProviderInfo { + out := make([]ProviderInfo, len(providerRegistry)) + copy(out, providerRegistry) + return out +} + +// ProviderByName returns the registered ProviderInfo with the given name, or +// nil if no such provider is registered. Name matching is exact. +func ProviderByName(name string) *ProviderInfo { + for i := range providerRegistry { + if providerRegistry[i].Name == name { + p := providerRegistry[i] + return &p + } + } + return nil +} diff --git a/v2/xai/xai.go b/v2/xai/xai.go new file mode 100644 index 0000000..500fd4c --- /dev/null +++ b/v2/xai/xai.go @@ -0,0 +1,29 @@ +// Package xai implements the go-llm v2 provider interface for xAI (Grok, +// https://x.ai/api). xAI speaks OpenAI Chat Completions, so this package is a +// thin wrapper over openaicompat with its own defaults and per-model Rules. +package xai + +import ( + "strings" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// DefaultBaseURL is the public xAI API endpoint. +const DefaultBaseURL = "https://api.x.ai/v1" + +// Provider is a type alias over openaicompat.Provider. +type Provider = openaicompat.Provider + +// New creates a new xAI provider. An empty baseURL uses DefaultBaseURL. +func New(apiKey, baseURL string) *Provider { + if baseURL == "" { + baseURL = DefaultBaseURL + } + return openaicompat.New(apiKey, baseURL, openaicompat.Rules{ + // Grok models whose name contains "vision" accept images; others don't. + SupportsVision: func(m string) bool { + return strings.Contains(m, "vision") + }, + }) +} diff --git a/v2/xai/xai_test.go b/v2/xai/xai_test.go new file mode 100644 index 0000000..bed5b6b --- /dev/null +++ b/v2/xai/xai_test.go @@ -0,0 +1,33 @@ +package xai_test + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai" +) + +func TestNew_Basic(t *testing.T) { + if p := xai.New("key", ""); p == nil { + t.Fatal("New returned nil") + } +} + +func TestRules_Grok2RejectsImages(t *testing.T) { + p := xai.New("key", "") + req := provider.Request{ + Model: "grok-2", + Messages: []provider.Message{{ + Role: "user", + Images: []provider.Image{{URL: "a"}}, + }}, + } + _, err := p.Complete(context.Background(), req) + var fue *openaicompat.FeatureUnsupportedError + if !errors.As(err, &fue) || fue.Feature != "vision" { + t.Fatalf("want FeatureUnsupportedError(vision), got %v", err) + } +}