feat: add DeepSeek, Moonshot, xAI, Groq, Ollama; drop v1; migrate TUI to v2
Five OpenAI-compatible providers join the library as first-class constructors (llm.DeepSeek, llm.Moonshot, llm.XAI, llm.Groq, llm.Ollama). Their wire-level implementation is shared via a new v2/openaicompat package which is the extracted guts of the old v2/openai provider; each provider supplies its own Rules value to declare per-model constraints (e.g., DeepSeek Reasoner rejects tools and temperature, Moonshot/xAI accept images only on *-vision* models, Groq rejects audio input). v2/openai itself becomes a thin wrapper that sets RestrictTemperature for o-series and gpt-5 models. A new provider registry (v2/registry.go) exposes llm.Providers() and drives the TUI's provider picker so adding a provider in future is a single-file change. The TUI at cmd/llm was migrated from v1 to v2 and moved to v2/cmd/llm. With nothing else depending on v1, the v1 code at the repo root (all .go files, schema/, internal/, provider/, root go.mod/go.sum) is deleted. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
# go-llm CLI environment variables
|
||||
# Copy this file to .env and fill in the keys for providers you use.
|
||||
|
||||
# OpenAI API Key (https://platform.openai.com/api-keys)
|
||||
OPENAI_API_KEY=
|
||||
|
||||
# Anthropic API Key (https://console.anthropic.com/settings/keys)
|
||||
ANTHROPIC_API_KEY=
|
||||
|
||||
# Google AI API Key (https://aistudio.google.com/apikey)
|
||||
GOOGLE_API_KEY=
|
||||
|
||||
# DeepSeek API Key (https://platform.deepseek.com)
|
||||
DEEPSEEK_API_KEY=
|
||||
|
||||
# Moonshot / Kimi API Key (https://platform.moonshot.ai)
|
||||
MOONSHOT_API_KEY=
|
||||
|
||||
# xAI / Grok API Key (https://x.ai/api)
|
||||
XAI_API_KEY=
|
||||
|
||||
# Groq API Key (https://console.groq.com/keys)
|
||||
GROQ_API_KEY=
|
||||
|
||||
# Ollama runs locally with no API key required.
|
||||
# Override the endpoint if you're not using localhost:11434.
|
||||
# OLLAMA_BASE_URL=http://localhost:11434/v1
|
||||
@@ -0,0 +1,136 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||
)
|
||||
|
||||
// Message types for async operations. Each is delivered back into
// Model.Update as a tea.Msg by the tea.Cmd closures in this file.

// ChatResponseMsg contains the response from a chat completion.
type ChatResponseMsg struct {
	Response llm.Response
	Err      error // non-nil if the provider call failed; Response is then unusable
}

// ToolExecutionMsg contains results from executing tool calls, one Message
// (RoleTool) per ToolCall, in the same order.
type ToolExecutionMsg struct {
	Results []llm.Message
	Err     error // non-nil if execution failed; Results may be partial
}

// ImageLoadedMsg contains a loaded image.
// Exactly one of Image/Err is meaningful: on error Image is the zero value.
type ImageLoadedMsg struct {
	Image llm.Image
	Err   error
}
|
||||
|
||||
// sendChatRequest sends a completion request with the current conversation,
|
||||
// returning a ChatResponseMsg tea.Msg when the provider responds.
|
||||
func sendChatRequest(model *llm.Model, messages []llm.Message, toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
opts := buildOpts(toolbox, toolsEnabled, temperature)
|
||||
resp, err := model.Complete(context.Background(), messages, opts...)
|
||||
return ChatResponseMsg{Response: resp, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// executeTools runs each tool call via the toolbox and returns ToolExecutionMsg
|
||||
// with one RoleTool Message per call, in the same order.
|
||||
func executeTools(toolbox *llm.ToolBox, calls []llm.ToolCall) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
ctx := context.Background()
|
||||
results, err := toolbox.ExecuteAll(ctx, calls)
|
||||
return ToolExecutionMsg{Results: results, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// buildOpts constructs RequestOptions from the current CLI state.
|
||||
func buildOpts(toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) []llm.RequestOption {
|
||||
var opts []llm.RequestOption
|
||||
if toolsEnabled && toolbox != nil && len(toolbox.AllTools()) > 0 {
|
||||
opts = append(opts, llm.WithTools(toolbox))
|
||||
}
|
||||
if temperature != nil {
|
||||
opts = append(opts, llm.WithTemperature(*temperature))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// loadImageFromPath loads an image from a file path.
|
||||
func loadImageFromPath(path string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
path = strings.TrimSpace(path)
|
||||
path = strings.Trim(path, "\"'")
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)}
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(data)
|
||||
if !strings.HasPrefix(contentType, "image/") {
|
||||
return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)}
|
||||
}
|
||||
|
||||
return ImageLoadedMsg{
|
||||
Image: llm.Image{
|
||||
Base64: base64.StdEncoding.EncodeToString(data),
|
||||
ContentType: contentType,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadImageFromURL loads an image from a URL (kept as URL, not fetched).
|
||||
func loadImageFromURL(url string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
return ImageLoadedMsg{Image: llm.Image{URL: strings.TrimSpace(url)}}
|
||||
}
|
||||
}
|
||||
|
||||
// loadImageFromBase64 loads an image from base64 data (raw or data: URL).
|
||||
func loadImageFromBase64(data string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
data = strings.TrimSpace(data)
|
||||
|
||||
if strings.HasPrefix(data, "data:") {
|
||||
parts := strings.SplitN(data, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")}
|
||||
}
|
||||
mediaType := strings.TrimPrefix(parts[0], "data:")
|
||||
mediaType = strings.TrimSuffix(mediaType, ";base64")
|
||||
return ImageLoadedMsg{
|
||||
Image: llm.Image{
|
||||
Base64: parts[1],
|
||||
ContentType: mediaType,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)}
|
||||
}
|
||||
contentType := http.DetectContentType(decoded)
|
||||
if !strings.HasPrefix(contentType, "image/") {
|
||||
return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)}
|
||||
}
|
||||
return ImageLoadedMsg{
|
||||
Image: llm.Image{
|
||||
Base64: data,
|
||||
ContentType: contentType,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load .env file if it exists (ignore error if not found)
|
||||
_ = godotenv.Load()
|
||||
|
||||
p := tea.NewProgram(
|
||||
InitialModel(),
|
||||
tea.WithAltScreen(),
|
||||
tea.WithMouseCellMotion(),
|
||||
)
|
||||
|
||||
if _, err := p.Run(); err != nil {
|
||||
fmt.Printf("Error running program: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||
)
|
||||
|
||||
// State represents the current view/screen of the application.
type State int

const (
	StateChat           State = iota // main conversation view (default after setup)
	StateProviderSelect              // picking a provider from the registry list
	StateModelSelect                 // picking a model for the current provider
	StateImageInput                  // entering an image path, URL, or base64 blob
	StateToolsPanel                  // toggling tool-calling on/off
	StateSettings                    // choosing a temperature preset
	StateAPIKeyInput                 // prompting for a provider's missing API key
)
|
||||
|
||||
// DisplayMessage represents a message for display in the UI.
// It is a render-only mirror of the conversation; the authoritative history
// lives in Model.conversation.
type DisplayMessage struct {
	Role    llm.Role
	Content string
	Images  int // number of images attached
}

// ProviderEntry is a CLI-local view of a registered provider, enriched with
// UI state (which model is currently chosen, whether we have a key, etc.).
type ProviderEntry struct {
	Info       llm.ProviderInfo
	HasAPIKey  bool // true when a key is on hand, or the provider needs none (EnvKey == "")
	ModelIndex int  // index into Info.Models of the currently chosen model
}
|
||||
|
||||
// Model is the main Bubble Tea model. It is passed by value through
// Update/View as Bubble Tea requires; pointer-receiver helpers mutate it
// in place before the updated copy is returned.
type Model struct {
	// State
	state         State
	previousState State // restored when a transient screen (e.g. image input) closes

	// Provider
	client        *llm.Client       // nil until a provider is selected
	chat          *llm.Model        // nil until a model is selected
	providerName  string            // display name of the active provider
	modelName     string            // name of the active model
	apiKeys       map[string]string // provider name -> API key (env or user-entered)
	providers     []ProviderEntry
	providerIndex int // index of the active provider in providers

	// Conversation
	conversation []llm.Message    // wire-level history sent to the provider
	messages     []DisplayMessage // render-only transcript shown in the viewport

	// Tools
	toolbox      *llm.ToolBox
	toolsEnabled bool

	// Settings
	systemPrompt string
	temperature  *float64 // nil means "provider default"

	// Pending images — staged attachments for the next user message
	pendingImages []llm.Image

	// UI Components
	input         textinput.Model
	viewport      viewport.Model
	viewportReady bool // viewport needs the first WindowSizeMsg before use

	// Selection state (for lists)
	listIndex int
	listItems []string

	// Dimensions
	width  int
	height int

	// Loading state
	loading bool // true while a completion or tool execution is in flight
	err     error

	// For API key input (password-masked)
	apiKeyInput textinput.Model
}
|
||||
|
||||
// InitialModel creates and returns the initial model.
// It builds the two text inputs, discovers providers from the go-llm
// registry (picking up API keys from the environment), and starts the UI on
// the provider-selection screen.
func InitialModel() Model {
	// Chat input: focused immediately so the user can type on arrival.
	ti := textinput.New()
	ti.Placeholder = "Type your message..."
	ti.Focus()
	ti.CharLimit = 4096
	ti.Width = 60

	// API key input: password-masked.
	aki := textinput.New()
	aki.Placeholder = "Enter API key..."
	aki.CharLimit = 256
	aki.Width = 60
	aki.EchoMode = textinput.EchoPassword

	// Build provider list from the go-llm registry.
	registry := llm.Providers()
	providers := make([]ProviderEntry, len(registry))
	apiKeys := make(map[string]string)

	for i, info := range registry {
		entry := ProviderEntry{Info: info}
		if info.EnvKey == "" {
			// Key-less provider (e.g., Ollama).
			entry.HasAPIKey = true
		} else if key := os.Getenv(info.EnvKey); key != "" {
			apiKeys[info.Name] = key
			entry.HasAPIKey = true
		}
		providers[i] = entry
	}

	m := Model{
		state:        StateProviderSelect,
		input:        ti,
		apiKeyInput:  aki,
		apiKeys:      apiKeys,
		providers:    providers,
		systemPrompt: "You are a helpful assistant.",
		toolbox:      createDemoToolbox(),
		toolsEnabled: false,
		messages:     []DisplayMessage{},
		conversation: []llm.Message{},
	}

	// Build list items for provider selection. Labels: "(local)" for
	// key-less providers, "(ready)" when a key exists, "(no key)" otherwise.
	m.listItems = make([]string, len(providers))
	for i, p := range providers {
		status := " (no key)"
		if p.HasAPIKey {
			status = " (ready)"
			if p.Info.EnvKey == "" {
				status = " (local)"
			}
		}
		m.listItems[i] = p.Info.DisplayName + status
	}

	return m
}
|
||||
|
||||
// Init initializes the model. It only schedules the cursor-blink command
// for the focused text input; all other startup work happens in
// InitialModel.
func (m Model) Init() tea.Cmd {
	return textinput.Blink
}
|
||||
|
||||
// selectProvider sets up the selected provider: it records the choice,
// constructs a client with the stored API key, and selects the provider's
// remembered default model.
//
// NOTE(review): out-of-range indexes and a missing required key are treated
// as silent no-ops (nil error) — callers are expected to have validated the
// selection via the UI flow before calling.
func (m *Model) selectProvider(index int) error {
	if index < 0 || index >= len(m.providers) {
		return nil
	}

	p := m.providers[index]
	key := m.apiKeys[p.Info.Name] // empty for key-less providers like Ollama

	// A provider that declares an env key cannot be used without one.
	if p.Info.EnvKey != "" && key == "" {
		return nil
	}

	m.providerName = p.Info.DisplayName
	m.providerIndex = index
	m.client = p.Info.New(key)

	// Select default model (the entry's remembered ModelIndex).
	if len(p.Info.Models) > 0 {
		return m.selectModel(p.ModelIndex)
	}

	return nil
}
|
||||
|
||||
// selectModel sets the current model by index into the active provider's
// model list and remembers the choice on the ProviderEntry so it survives
// provider switches. No-client and out-of-range indexes are silent no-ops.
func (m *Model) selectModel(index int) error {
	if m.client == nil {
		return nil
	}

	p := m.providers[m.providerIndex]
	if index < 0 || index >= len(p.Info.Models) {
		return nil
	}

	modelName := p.Info.Models[index]
	m.chat = m.client.Model(modelName)
	m.modelName = modelName
	m.providers[m.providerIndex].ModelIndex = index

	return nil
}
|
||||
|
||||
// newConversation resets the conversation.
|
||||
func (m *Model) newConversation() {
|
||||
m.conversation = []llm.Message{}
|
||||
m.messages = []DisplayMessage{}
|
||||
m.pendingImages = []llm.Image{}
|
||||
m.err = nil
|
||||
}
|
||||
|
||||
// addUserMessage adds a user message to the conversation.
|
||||
func (m *Model) addUserMessage(text string, images []llm.Image) {
|
||||
msg := llm.Message{
|
||||
Role: llm.RoleUser,
|
||||
Content: llm.Content{Text: text, Images: images},
|
||||
}
|
||||
m.conversation = append(m.conversation, msg)
|
||||
m.messages = append(m.messages, DisplayMessage{
|
||||
Role: llm.RoleUser,
|
||||
Content: text,
|
||||
Images: len(images),
|
||||
})
|
||||
}
|
||||
|
||||
// addAssistantMessage adds an assistant message to the conversation display.
|
||||
func (m *Model) addAssistantMessage(content string) {
|
||||
m.messages = append(m.messages, DisplayMessage{
|
||||
Role: llm.RoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
// addToolCallMessage adds a tool call message to display.
|
||||
func (m *Model) addToolCallMessage(name string, args string) {
|
||||
m.messages = append(m.messages, DisplayMessage{
|
||||
Role: llm.Role("tool_call"),
|
||||
Content: name + ": " + args,
|
||||
})
|
||||
}
|
||||
|
||||
// addToolResultMessage adds a tool result message to display.
|
||||
func (m *Model) addToolResultMessage(name string, result string) {
|
||||
m.messages = append(m.messages, DisplayMessage{
|
||||
Role: llm.Role("tool_result"),
|
||||
Content: name + " -> " + result,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
	// Colors — ANSI-256 palette indexes.
	primaryColor   = lipgloss.Color("205")
	secondaryColor = lipgloss.Color("39")
	accentColor    = lipgloss.Color("212")
	mutedColor     = lipgloss.Color("241")
	errorColor     = lipgloss.Color("196")
	successColor   = lipgloss.Color("82")

	// App styles — outer padding around the whole UI.
	appStyle = lipgloss.NewStyle().Padding(1, 2)

	// Header — bold title with a bottom rule.
	headerStyle = lipgloss.NewStyle().
			Bold(true).
			Foreground(primaryColor).
			BorderStyle(lipgloss.NormalBorder()).
			BorderBottom(true).
			BorderForeground(mutedColor).
			Padding(0, 1)

	// Provider badge — inverse pill showing the active provider.
	providerBadgeStyle = lipgloss.NewStyle().
				Background(secondaryColor).
				Foreground(lipgloss.Color("0")).
				Padding(0, 1).
				Bold(true)

	// Messages — one style per role in the transcript.
	systemMsgStyle = lipgloss.NewStyle().
			Foreground(mutedColor).
			Italic(true).
			Padding(0, 1)

	userMsgStyle = lipgloss.NewStyle().
			Foreground(secondaryColor).
			Padding(0, 1)

	assistantMsgStyle = lipgloss.NewStyle().
				Foreground(lipgloss.Color("255")).
				Padding(0, 1)

	// Fixed-width role column so message bodies align.
	roleLabelStyle = lipgloss.NewStyle().
			Bold(true).
			Width(12)

	// Tool calls
	toolCallStyle = lipgloss.NewStyle().
			Foreground(accentColor).
			Italic(true).
			Padding(0, 1)

	toolResultStyle = lipgloss.NewStyle().
			Foreground(successColor).
			Padding(0, 1)

	// Input area — rounded border around the text input.
	inputStyle = lipgloss.NewStyle().
			BorderStyle(lipgloss.RoundedBorder()).
			BorderForeground(primaryColor).
			Padding(0, 1)

	inputHelpStyle = lipgloss.NewStyle().
			Foreground(mutedColor).
			Italic(true)

	// Error
	errorStyle = lipgloss.NewStyle().
			Foreground(errorColor).
			Bold(true)

	// Loading
	loadingStyle = lipgloss.NewStyle().
			Foreground(accentColor).
			Italic(true)

	// List selection — highlighted vs. normal rows.
	selectedItemStyle = lipgloss.NewStyle().
				Foreground(primaryColor).
				Bold(true)

	normalItemStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("255"))

	// Settings panel — fixed-width labels with plain values.
	settingLabelStyle = lipgloss.NewStyle().
				Foreground(secondaryColor).
				Width(15)

	settingValueStyle = lipgloss.NewStyle().
				Foreground(lipgloss.Color("255"))

	// Help text
	helpStyle = lipgloss.NewStyle().
			Foreground(mutedColor).
			Padding(1, 0)

	// Image indicator
	imageIndicatorStyle = lipgloss.NewStyle().
				Foreground(accentColor).
				Bold(true)

	// Viewport — plain border around the scrolling transcript.
	viewportStyle = lipgloss.NewStyle().
			BorderStyle(lipgloss.NormalBorder()).
			BorderForeground(mutedColor)
)
|
||||
@@ -0,0 +1,114 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||
)
|
||||
|
||||
// TimeParams is the parameter struct for the GetTime function.
// It is intentionally empty: the tool takes no arguments.
type TimeParams struct{}

// GetTime returns the current local time as a human-readable string.
func GetTime(_ context.Context, _ TimeParams) (string, error) {
	const layout = "Monday, January 2, 2006 3:04:05 PM MST"
	return time.Now().Format(layout), nil
}
|
||||
|
||||
// CalcParams is the parameter struct for the Calculate function.
type CalcParams struct {
	A  float64 `json:"a" description:"First number"`
	B  float64 `json:"b" description:"Second number"`
	Op string  `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"`
}

// Calculate performs basic math operations on A and B (sqrt uses only A).
// The operation name is case-insensitive and symbol aliases are accepted.
// Results are formatted with the minimal decimal representation.
//
// Fix: "mod" with B == 0 previously returned the string "NaN" (math.Mod's
// NaN result) with a nil error; it is now rejected explicitly, matching the
// existing "divide" behavior.
func Calculate(_ context.Context, params CalcParams) (string, error) {
	var result float64
	switch strings.ToLower(params.Op) {
	case "add", "+":
		result = params.A + params.B
	case "subtract", "sub", "-":
		result = params.A - params.B
	case "multiply", "mul", "*":
		result = params.A * params.B
	case "divide", "div", "/":
		if params.B == 0 {
			return "", fmt.Errorf("division by zero")
		}
		result = params.A / params.B
	case "power", "pow", "^":
		result = math.Pow(params.A, params.B)
	case "sqrt":
		if params.A < 0 {
			return "", fmt.Errorf("cannot take square root of negative number")
		}
		result = math.Sqrt(params.A)
	case "mod", "%":
		// math.Mod(x, 0) yields NaN; reject it like division does.
		if params.B == 0 {
			return "", fmt.Errorf("modulo by zero")
		}
		result = math.Mod(params.A, params.B)
	default:
		return "", fmt.Errorf("unknown operation: %s", params.Op)
	}
	return strconv.FormatFloat(result, 'f', -1, 64), nil
}
|
||||
|
||||
// WeatherParams is the parameter struct for the GetWeather function.
type WeatherParams struct {
	Location string `json:"location" description:"City name or location"`
}

// GetWeather returns mock weather data (for demo purposes). The "forecast"
// is chosen deterministically from the length of the location string.
func GetWeather(_ context.Context, params WeatherParams) (string, error) {
	conditions := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"}
	temperatures := []int{65, 72, 58, 80, 45}
	pick := len(params.Location) % len(conditions)

	payload := map[string]any{
		"location":    params.Location,
		"temperature": strconv.Itoa(temperatures[pick]) + "F",
		"condition":   conditions[pick],
		"humidity":    "45%",
		"note":        "This is mock data for demonstration purposes",
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}
|
||||
|
||||
// RandomNumberParams is the parameter struct for the RandomNumber function.
|
||||
type RandomNumberParams struct {
|
||||
Min int `json:"min" description:"Minimum value (inclusive)"`
|
||||
Max int `json:"max" description:"Maximum value (inclusive)"`
|
||||
}
|
||||
|
||||
// RandomNumber generates a pseudo-random number (using current time nanoseconds).
|
||||
func RandomNumber(_ context.Context, params RandomNumberParams) (string, error) {
|
||||
if params.Min > params.Max {
|
||||
return "", fmt.Errorf("min cannot be greater than max")
|
||||
}
|
||||
n := time.Now().UnixNano()
|
||||
rangeSize := params.Max - params.Min + 1
|
||||
result := params.Min + int(n%int64(rangeSize))
|
||||
return strconv.Itoa(result), nil
|
||||
}
|
||||
|
||||
// createDemoToolbox creates a toolbox with demo tools for testing.
// The tool names and descriptions here are what the model sees when tools
// are enabled; the Go functions implement them.
func createDemoToolbox() *llm.ToolBox {
	return llm.NewToolBox(
		llm.Define[TimeParams]("get_time", "Get the current date and time", GetTime),
		llm.Define[CalcParams]("calculate",
			"Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)",
			Calculate),
		llm.Define[WeatherParams]("get_weather",
			"Get weather information for a location (demo data)", GetWeather),
		llm.Define[RandomNumberParams]("random_number",
			"Generate a random number between min and max", RandomNumber),
	)
}
|
||||
@@ -0,0 +1,409 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||
)
|
||||
|
||||
// pendingToolCalls stores the last response's tool calls so we can pair them
// with tool execution results for display.
// NOTE(review): this is mutable package-level state shared by all Model
// values; fine for a single-program TUI, but it would race if multiple
// programs ran in one process — consider moving onto Model.
var pendingToolCalls []llm.ToolCall

// Update handles messages and updates the model. It is the single event
// loop entry point: key events are delegated to per-state handlers, window
// resizes re-layout the viewport, and the async ChatResponseMsg /
// ToolExecutionMsg / ImageLoadedMsg results drive the chat + tool-call
// round-trip.
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	var cmd tea.Cmd
	var cmds []tea.Cmd

	switch msg := msg.(type) {
	case tea.KeyMsg:
		return m.handleKeyMsg(msg)

	case tea.WindowSizeMsg:
		m.width = msg.Width
		m.height = msg.Height

		// Fixed chrome around the viewport: header above, input + help below.
		headerHeight := 3
		footerHeight := 4
		verticalMargins := headerHeight + footerHeight

		// The viewport can only be constructed once we know the window size.
		if !m.viewportReady {
			m.viewport = viewport.New(msg.Width-4, msg.Height-verticalMargins)
			m.viewport.HighPerformanceRendering = false
			m.viewportReady = true
		} else {
			m.viewport.Width = msg.Width - 4
			m.viewport.Height = msg.Height - verticalMargins
		}

		m.input.Width = msg.Width - 6
		m.apiKeyInput.Width = msg.Width - 6

		m.viewport.SetContent(m.renderMessages())

	case ChatResponseMsg:
		m.loading = false
		if msg.Err != nil {
			m.err = msg.Err
			return m, nil
		}

		resp := msg.Response

		// Add the assistant message to the conversation history.
		m.conversation = append(m.conversation, resp.Message())

		// Show any text the assistant produced alongside tool calls.
		if resp.Text != "" {
			m.addAssistantMessage(resp.Text)
		}

		if resp.HasToolCalls() && m.toolsEnabled {
			pendingToolCalls = resp.ToolCalls

			for _, call := range resp.ToolCalls {
				m.addToolCallMessage(call.Name, call.Arguments)
			}

			m.viewport.SetContent(m.renderMessages())
			m.viewport.GotoBottom()

			// Stay in the loading state while the tools run.
			m.loading = true
			return m, executeTools(m.toolbox, resp.ToolCalls)
		}

		m.viewport.SetContent(m.renderMessages())
		m.viewport.GotoBottom()

	case ToolExecutionMsg:
		if msg.Err != nil {
			m.loading = false
			m.err = msg.Err
			return m, nil
		}

		// Display results paired with the tool calls that produced them.
		for i, result := range msg.Results {
			name := ""
			if i < len(pendingToolCalls) {
				name = pendingToolCalls[i].Name
			}
			m.addToolResultMessage(name, result.Content.Text)
		}

		// Append the raw tool result messages to the conversation so the
		// assistant can reference them on the next turn.
		m.conversation = append(m.conversation, msg.Results...)

		m.viewport.SetContent(m.renderMessages())
		m.viewport.GotoBottom()

		// Ask the model to continue given the tool results. Note loading
		// stays true: the follow-up completion is already in flight.
		return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature)

	case ImageLoadedMsg:
		if msg.Err != nil {
			m.err = msg.Err
			m.state = m.previousState
			return m, nil
		}

		// Stage the image; it is attached to the next user message.
		m.pendingImages = append(m.pendingImages, msg.Image)
		m.state = m.previousState
		m.err = nil

	default:
		// Update text input (cursor blink etc.) for whichever input is active.
		if m.state == StateChat {
			m.input, cmd = m.input.Update(msg)
			cmds = append(cmds, cmd)
		} else if m.state == StateAPIKeyInput {
			m.apiKeyInput, cmd = m.apiKeyInput.Update(msg)
			cmds = append(cmds, cmd)
		}
	}

	return m, tea.Batch(cmds...)
}
|
||||
|
||||
// handleKeyMsg handles keyboard input. Ctrl+C always quits; Esc returns to
// the chat from any sub-screen (and quits from the chat itself). Everything
// else is routed to the handler for the current state.
func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
	switch msg.String() {
	case "ctrl+c":
		return m, tea.Quit
	case "esc":
		if m.state != StateChat {
			m.state = StateChat
			m.input.Focus()
			return m, nil
		}
		return m, tea.Quit
	}

	switch m.state {
	case StateChat:
		return m.handleChatKeys(msg)
	case StateProviderSelect:
		return m.handleProviderSelectKeys(msg)
	case StateModelSelect:
		return m.handleModelSelectKeys(msg)
	case StateImageInput:
		return m.handleImageInputKeys(msg)
	case StateToolsPanel:
		return m.handleToolsPanelKeys(msg)
	case StateSettings:
		return m.handleSettingsKeys(msg)
	case StateAPIKeyInput:
		return m.handleAPIKeyInputKeys(msg)
	}

	return m, nil
}
|
||||
|
||||
// handleChatKeys handles keys in chat state.
// Enter sends the drafted message; Ctrl+I attaches an image; Ctrl+T opens
// the tools panel; Ctrl+P/Ctrl+M switch provider/model; Ctrl+S opens
// settings; Ctrl+N starts a new conversation. Arrow/page keys scroll the
// transcript; anything else edits the input.
func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
	switch msg.String() {
	case "enter":
		// Ignore sends while a request is already in flight.
		if m.loading {
			return m, nil
		}

		text := strings.TrimSpace(m.input.Value())
		if text == "" {
			return m, nil
		}

		if m.chat == nil {
			m.err = fmt.Errorf("no model selected - press Ctrl+P to select a provider")
			return m, nil
		}

		// Ensure a system message is at the head of the conversation.
		if len(m.conversation) == 0 && m.systemPrompt != "" {
			m.conversation = append(m.conversation, llm.SystemMessage(m.systemPrompt))
		}

		m.addUserMessage(text, m.pendingImages)

		// Reset the draft; staged images are consumed by this message.
		m.input.Reset()
		m.pendingImages = nil
		m.err = nil
		m.loading = true

		m.viewport.SetContent(m.renderMessages())
		m.viewport.GotoBottom()

		return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature)

	case "ctrl+i":
		// Repurpose the chat input for image entry; restored on return.
		m.previousState = StateChat
		m.state = StateImageInput
		m.input.SetValue("")
		m.input.Placeholder = "Enter image path or URL..."
		return m, nil

	case "ctrl+t":
		m.state = StateToolsPanel
		return m, nil

	case "ctrl+p":
		m.state = StateProviderSelect
		m.listIndex = m.providerIndex
		return m, nil

	case "ctrl+m":
		if m.client == nil {
			m.err = fmt.Errorf("select a provider first")
			return m, nil
		}
		m.state = StateModelSelect
		m.listItems = m.providers[m.providerIndex].Info.Models
		m.listIndex = m.providers[m.providerIndex].ModelIndex
		return m, nil

	case "ctrl+s":
		m.state = StateSettings
		return m, nil

	case "ctrl+n":
		m.newConversation()
		m.viewport.SetContent(m.renderMessages())
		return m, nil

	case "up", "down", "pgup", "pgdown":
		// Scroll the transcript rather than editing the input.
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd

	default:
		var cmd tea.Cmd
		m.input, cmd = m.input.Update(msg)
		return m, cmd
	}
}
|
||||
|
||||
// handleProviderSelectKeys handles keys in provider selection state.
|
||||
func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "up", "k":
|
||||
if m.listIndex > 0 {
|
||||
m.listIndex--
|
||||
}
|
||||
case "down", "j":
|
||||
if m.listIndex < len(m.providers)-1 {
|
||||
m.listIndex++
|
||||
}
|
||||
case "enter":
|
||||
p := m.providers[m.listIndex]
|
||||
if !p.HasAPIKey {
|
||||
m.state = StateAPIKeyInput
|
||||
m.apiKeyInput.Focus()
|
||||
m.apiKeyInput.SetValue("")
|
||||
return m, textinput.Blink
|
||||
}
|
||||
|
||||
if err := m.selectProvider(m.listIndex); err != nil {
|
||||
m.err = err
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.state = StateChat
|
||||
m.input.Focus()
|
||||
m.newConversation()
|
||||
return m, nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// handleAPIKeyInputKeys handles keys in API key input state.
// Enter stores the typed key for the provider under the cursor, refreshes
// every provider label's status suffix, activates the provider, and returns
// to chat. All other keys edit the masked input.
// NOTE(review): the status-label loop duplicates the labeling logic in
// InitialModel — a shared helper would keep the two in sync.
func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
	switch msg.String() {
	case "enter":
		key := strings.TrimSpace(m.apiKeyInput.Value())
		if key == "" {
			return m, nil
		}

		// m.listIndex still points at the provider chosen on the previous
		// (provider-select) screen.
		p := m.providers[m.listIndex]
		m.apiKeys[p.Info.Name] = key
		m.providers[m.listIndex].HasAPIKey = true

		// Rebuild every label so the new key is reflected in the list.
		for i, prov := range m.providers {
			status := " (no key)"
			if prov.HasAPIKey {
				status = " (ready)"
				if prov.Info.EnvKey == "" {
					status = " (local)"
				}
			}
			m.listItems[i] = prov.Info.DisplayName + status
		}

		if err := m.selectProvider(m.listIndex); err != nil {
			m.err = err
			return m, nil
		}

		m.state = StateChat
		m.input.Focus()
		m.newConversation()
		return m, nil

	default:
		var cmd tea.Cmd
		m.apiKeyInput, cmd = m.apiKeyInput.Update(msg)
		return m, cmd
	}
}
|
||||
|
||||
// handleModelSelectKeys handles keys in model selection state.
|
||||
func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "up", "k":
|
||||
if m.listIndex > 0 {
|
||||
m.listIndex--
|
||||
}
|
||||
case "down", "j":
|
||||
if m.listIndex < len(m.listItems)-1 {
|
||||
m.listIndex++
|
||||
}
|
||||
case "enter":
|
||||
if err := m.selectModel(m.listIndex); err != nil {
|
||||
m.err = err
|
||||
return m, nil
|
||||
}
|
||||
m.state = StateChat
|
||||
m.input.Focus()
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// handleImageInputKeys handles keys in image input state.
|
||||
func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "enter":
|
||||
input := strings.TrimSpace(m.input.Value())
|
||||
if input == "" {
|
||||
m.state = m.previousState
|
||||
m.input.Placeholder = "Type your message..."
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.input.Placeholder = "Type your message..."
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://"):
|
||||
return m, loadImageFromURL(input)
|
||||
case strings.HasPrefix(input, "data:") || (len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\")):
|
||||
return m, loadImageFromBase64(input)
|
||||
default:
|
||||
return m, loadImageFromPath(input)
|
||||
}
|
||||
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.input, cmd = m.input.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
// handleToolsPanelKeys handles keys in tools panel state.
|
||||
func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "t":
|
||||
m.toolsEnabled = !m.toolsEnabled
|
||||
case "enter", "q":
|
||||
m.state = StateChat
|
||||
m.input.Focus()
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// handleSettingsKeys handles keys in settings state.
|
||||
func (m Model) handleSettingsKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "1":
|
||||
m.temperature = nil
|
||||
case "2":
|
||||
t := 0.0
|
||||
m.temperature = &t
|
||||
case "3":
|
||||
t := 0.5
|
||||
m.temperature = &t
|
||||
case "4":
|
||||
t := 0.7
|
||||
m.temperature = &t
|
||||
case "5":
|
||||
t := 1.0
|
||||
m.temperature = &t
|
||||
case "enter", "q":
|
||||
m.state = StateChat
|
||||
m.input.Focus()
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -0,0 +1,291 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||
)
|
||||
|
||||
// View renders the current state.
|
||||
func (m Model) View() string {
|
||||
switch m.state {
|
||||
case StateProviderSelect:
|
||||
return m.renderProviderSelect()
|
||||
case StateModelSelect:
|
||||
return m.renderModelSelect()
|
||||
case StateImageInput:
|
||||
return m.renderImageInput()
|
||||
case StateToolsPanel:
|
||||
return m.renderToolsPanel()
|
||||
case StateSettings:
|
||||
return m.renderSettings()
|
||||
case StateAPIKeyInput:
|
||||
return m.renderAPIKeyInput()
|
||||
default:
|
||||
return m.renderChat()
|
||||
}
|
||||
}
|
||||
|
||||
// renderChat renders the main chat view.
|
||||
func (m Model) renderChat() string {
|
||||
var b strings.Builder
|
||||
|
||||
provider := m.providerName
|
||||
if provider == "" {
|
||||
provider = "None"
|
||||
}
|
||||
model := m.modelName
|
||||
if model == "" {
|
||||
model = "None"
|
||||
}
|
||||
|
||||
header := headerStyle.Render(fmt.Sprintf("go-llm CLI %s",
|
||||
providerBadgeStyle.Render(fmt.Sprintf("%s/%s", provider, model))))
|
||||
|
||||
b.WriteString(header)
|
||||
b.WriteString("\n")
|
||||
|
||||
if m.viewportReady {
|
||||
b.WriteString(m.viewport.View())
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(m.pendingImages) > 0 {
|
||||
b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s) attached]", len(m.pendingImages))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.err != nil {
|
||||
b.WriteString(errorStyle.Render(" Error: " + m.err.Error()))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.loading {
|
||||
b.WriteString(loadingStyle.Render(" Thinking..."))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
inputBox := inputStyle.Render(m.input.View())
|
||||
b.WriteString(inputBox)
|
||||
b.WriteString("\n")
|
||||
|
||||
help := inputHelpStyle.Render("Enter: send | Ctrl+I: image | Ctrl+T: tools | Ctrl+P: provider | Ctrl+M: model | Ctrl+S: settings | Ctrl+N: new | Esc: quit")
|
||||
b.WriteString(help)
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
|
||||
// renderMessages renders all messages for the viewport.
|
||||
func (m Model) renderMessages() string {
|
||||
var b strings.Builder
|
||||
|
||||
if len(m.messages) == 0 {
|
||||
b.WriteString(systemMsgStyle.Render("[System] " + m.systemPrompt))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(mutedColor).Render("Start a conversation by typing a message below."))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
b.WriteString(systemMsgStyle.Render("[System] " + m.systemPrompt))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for _, msg := range m.messages {
|
||||
var content string
|
||||
var style lipgloss.Style
|
||||
|
||||
switch msg.Role {
|
||||
case llm.RoleUser:
|
||||
style = userMsgStyle
|
||||
label := roleLabelStyle.Foreground(secondaryColor).Render("[User]")
|
||||
content = label + " " + msg.Content
|
||||
if msg.Images > 0 {
|
||||
content += imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s)]", msg.Images))
|
||||
}
|
||||
case llm.RoleAssistant:
|
||||
style = assistantMsgStyle
|
||||
label := roleLabelStyle.Foreground(lipgloss.Color("255")).Render("[Assistant]")
|
||||
content = label + " " + msg.Content
|
||||
case llm.Role("tool_call"):
|
||||
style = toolCallStyle
|
||||
content = " -> Calling: " + msg.Content
|
||||
case llm.Role("tool_result"):
|
||||
style = toolResultStyle
|
||||
content = " <- Result: " + msg.Content
|
||||
default:
|
||||
style = assistantMsgStyle
|
||||
content = msg.Content
|
||||
}
|
||||
|
||||
b.WriteString(style.Render(content))
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderProviderSelect renders the provider selection view.
|
||||
func (m Model) renderProviderSelect() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(headerStyle.Render("Select Provider"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i, item := range m.listItems {
|
||||
cursor := " "
|
||||
style := normalItemStyle
|
||||
if i == m.listIndex {
|
||||
cursor = "> "
|
||||
style = selectedItemStyle
|
||||
}
|
||||
b.WriteString(style.Render(cursor + item))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render("Use arrow keys or j/k to navigate, Enter to select, Esc to cancel"))
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
|
||||
// renderAPIKeyInput renders the API key input view.
|
||||
func (m Model) renderAPIKeyInput() string {
|
||||
var b strings.Builder
|
||||
|
||||
provider := m.providers[m.listIndex]
|
||||
|
||||
b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Info.DisplayName)))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if provider.Info.EnvKey != "" {
|
||||
b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.Info.EnvKey))
|
||||
}
|
||||
b.WriteString("Enter your API key below (it will be hidden):\n\n")
|
||||
|
||||
inputBox := inputStyle.Render(m.apiKeyInput.View())
|
||||
b.WriteString(inputBox)
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(helpStyle.Render("Enter to confirm, Esc to cancel"))
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
|
||||
// renderModelSelect renders the model selection view.
|
||||
func (m Model) renderModelSelect() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(headerStyle.Render(fmt.Sprintf("Select Model (%s)", m.providerName)))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i, item := range m.listItems {
|
||||
cursor := " "
|
||||
style := normalItemStyle
|
||||
if i == m.listIndex {
|
||||
cursor = "> "
|
||||
style = selectedItemStyle
|
||||
}
|
||||
if item == m.modelName {
|
||||
item += " (current)"
|
||||
}
|
||||
b.WriteString(style.Render(cursor + item))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render("Use arrow keys or j/k to navigate, Enter to select, Esc to cancel"))
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
|
||||
// renderImageInput renders the image input view.
|
||||
func (m Model) renderImageInput() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(headerStyle.Render("Add Image"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString("Enter an image source:\n")
|
||||
b.WriteString(" - File path (e.g., /path/to/image.png)\n")
|
||||
b.WriteString(" - URL (e.g., https://example.com/image.jpg)\n")
|
||||
b.WriteString(" - Base64 data or data URL\n\n")
|
||||
|
||||
if len(m.pendingImages) > 0 {
|
||||
b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf("Currently attached: %d image(s)\n\n", len(m.pendingImages))))
|
||||
}
|
||||
|
||||
inputBox := inputStyle.Render(m.input.View())
|
||||
b.WriteString(inputBox)
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(helpStyle.Render("Enter to add image, Esc to cancel"))
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
|
||||
// renderToolsPanel renders the tools panel.
|
||||
func (m Model) renderToolsPanel() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(headerStyle.Render("Tools / Function Calling"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
status := "DISABLED"
|
||||
statusStyle := errorStyle
|
||||
if m.toolsEnabled {
|
||||
status = "ENABLED"
|
||||
statusStyle = lipgloss.NewStyle().Foreground(successColor).Bold(true)
|
||||
}
|
||||
|
||||
b.WriteString(settingLabelStyle.Render("Tools Status:"))
|
||||
b.WriteString(statusStyle.Render(status))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString("Available tools:\n")
|
||||
if m.toolbox != nil {
|
||||
for _, t := range m.toolbox.AllTools() {
|
||||
b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(t.Name), t.Description))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render("Press 't' to toggle tools, Enter or 'q' to close"))
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
|
||||
// renderSettings renders the settings view.
|
||||
func (m Model) renderSettings() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(headerStyle.Render("Settings"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
tempStr := "default"
|
||||
if m.temperature != nil {
|
||||
tempStr = fmt.Sprintf("%.1f", *m.temperature)
|
||||
}
|
||||
b.WriteString(settingLabelStyle.Render("Temperature:"))
|
||||
b.WriteString(settingValueStyle.Render(tempStr))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString("Press a key to set temperature:\n")
|
||||
b.WriteString(" 1 - Default (model decides)\n")
|
||||
b.WriteString(" 2 - 0.0 (deterministic)\n")
|
||||
b.WriteString(" 3 - 0.5 (balanced)\n")
|
||||
b.WriteString(" 4 - 0.7 (creative)\n")
|
||||
b.WriteString(" 5 - 1.0 (very creative)\n")
|
||||
|
||||
b.WriteString("\n")
|
||||
|
||||
b.WriteString(settingLabelStyle.Render("System Prompt:"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(settingValueStyle.Render(" " + m.systemPrompt))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(helpStyle.Render("Enter or 'q' to close"))
|
||||
|
||||
return appStyle.Render(b.String())
|
||||
}
|
||||
@@ -2,8 +2,13 @@ package llm
|
||||
|
||||
import (
|
||||
anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic"
|
||||
deepseekProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
|
||||
googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google"
|
||||
groqProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
|
||||
moonshotProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
|
||||
ollamaProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
|
||||
openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
|
||||
xaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
|
||||
)
|
||||
|
||||
// OpenAI creates an OpenAI client.
|
||||
@@ -46,3 +51,69 @@ func Google(apiKey string, opts ...ClientOption) *Client {
|
||||
_ = cfg // Google doesn't support custom base URL in the SDK
|
||||
return NewClient(googleProvider.New(apiKey))
|
||||
}
|
||||
|
||||
// DeepSeek creates a DeepSeek client (OpenAI-compatible).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// model := llm.DeepSeek("sk-...").Model("deepseek-chat")
|
||||
func DeepSeek(apiKey string, opts ...ClientOption) *Client {
|
||||
cfg := &clientConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return NewClient(deepseekProvider.New(apiKey, cfg.baseURL))
|
||||
}
|
||||
|
||||
// Moonshot creates a Moonshot AI (Kimi) client (OpenAI-compatible).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// model := llm.Moonshot("sk-...").Model("kimi-k2-0711-preview")
|
||||
func Moonshot(apiKey string, opts ...ClientOption) *Client {
|
||||
cfg := &clientConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return NewClient(moonshotProvider.New(apiKey, cfg.baseURL))
|
||||
}
|
||||
|
||||
// XAI creates an xAI (Grok) client (OpenAI-compatible).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// model := llm.XAI("xai-...").Model("grok-2")
|
||||
func XAI(apiKey string, opts ...ClientOption) *Client {
|
||||
cfg := &clientConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return NewClient(xaiProvider.New(apiKey, cfg.baseURL))
|
||||
}
|
||||
|
||||
// Groq creates a Groq client (OpenAI-compatible).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// model := llm.Groq("gsk-...").Model("llama-3.3-70b-versatile")
|
||||
func Groq(apiKey string, opts ...ClientOption) *Client {
|
||||
cfg := &clientConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return NewClient(groqProvider.New(apiKey, cfg.baseURL))
|
||||
}
|
||||
|
||||
// Ollama creates a client for a local Ollama instance (OpenAI-compatible).
|
||||
// No API key is required. Use WithBaseURL to point at a non-default host/port.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// model := llm.Ollama().Model("llama3.2")
|
||||
func Ollama(opts ...ClientOption) *Client {
|
||||
cfg := &clientConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return NewClient(ollamaProvider.New("", cfg.baseURL))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
// Package deepseek implements the go-llm v2 provider interface for DeepSeek
|
||||
// (https://platform.deepseek.com). DeepSeek speaks the OpenAI Chat Completions
|
||||
// protocol, so this package is a thin wrapper around openaicompat with its own
|
||||
// defaults and per-model Rules.
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
)
|
||||
|
||||
// DefaultBaseURL is the public DeepSeek API endpoint.
|
||||
const DefaultBaseURL = "https://api.deepseek.com/v1"
|
||||
|
||||
// Provider is a type alias over openaicompat.Provider.
|
||||
type Provider = openaicompat.Provider
|
||||
|
||||
// New creates a new DeepSeek provider. An empty baseURL uses DefaultBaseURL.
|
||||
func New(apiKey, baseURL string) *Provider {
|
||||
if baseURL == "" {
|
||||
baseURL = DefaultBaseURL
|
||||
}
|
||||
return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
|
||||
// DeepSeek's chat and reasoner models are text-only.
|
||||
SupportsVision: func(string) bool { return false },
|
||||
// Reasoner doesn't accept tool calls.
|
||||
SupportsTools: func(m string) bool {
|
||||
return !strings.Contains(m, "reasoner")
|
||||
},
|
||||
// Reasoner rejects user-supplied temperature.
|
||||
RestrictTemperature: func(m string) bool {
|
||||
return strings.Contains(m, "reasoner")
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package deepseek_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestNew_DefaultBaseURL(t *testing.T) {
|
||||
if p := deepseek.New("key", ""); p == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRules_ReasonerRejectsTools(t *testing.T) {
|
||||
p := deepseek.New("key", "")
|
||||
req := provider.Request{
|
||||
Model: "deepseek-reasoner",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
Tools: []provider.ToolDef{
|
||||
{Name: "x", Schema: map[string]any{"type": "object"}},
|
||||
},
|
||||
}
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) || fue.Feature != "tools" {
|
||||
t.Fatalf("want FeatureUnsupportedError(tools), got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRules_ChatRejectsImages(t *testing.T) {
|
||||
p := deepseek.New("key", "")
|
||||
req := provider.Request{
|
||||
Model: "deepseek-chat",
|
||||
Messages: []provider.Message{{
|
||||
Role: "user",
|
||||
Images: []provider.Image{{URL: "a"}},
|
||||
}},
|
||||
}
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) || fue.Feature != "vision" {
|
||||
t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
module gitea.stevedudenhoeffer.com/steve/go-llm/v2
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.2
|
||||
go 1.24.2
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/liushuangls/go-anthropic/v2 v2.17.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||
github.com/openai/openai-go v1.12.0
|
||||
@@ -18,6 +20,16 @@ require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/auth v0.9.3 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.5.0 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
@@ -25,15 +37,24 @@ require (
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/gjson v1.14.4 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
golang.org/x/net v0.42.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||
google.golang.org/grpc v1.66.2 // indirect
|
||||
|
||||
@@ -6,8 +6,32 @@ cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842Bg
|
||||
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
|
||||
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -16,6 +40,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
@@ -49,12 +75,28 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
|
||||
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
|
||||
@@ -62,6 +104,8 @@ github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1Hbe
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -80,6 +124,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
@@ -89,6 +135,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
|
||||
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@@ -114,8 +162,10 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
||||
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// Package groq implements the go-llm v2 provider interface for Groq
// (https://console.groq.com). Groq serves open-source models through an
// OpenAI Chat Completions-compatible API, so this package simply layers its
// own endpoint default and per-model Rules on top of openaicompat.
package groq

import (
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)

// DefaultBaseURL is the public Groq OpenAI-compatible endpoint.
const DefaultBaseURL = "https://api.groq.com/openai/v1"

// Provider is a type alias over openaicompat.Provider.
type Provider = openaicompat.Provider

// New creates a new Groq provider. An empty baseURL uses DefaultBaseURL.
func New(apiKey, baseURL string) *Provider {
	endpoint := baseURL
	if endpoint == "" {
		endpoint = DefaultBaseURL
	}
	rules := openaicompat.Rules{
		// Images are only accepted by Groq-hosted vision variants
		// (e.g. *-vision-preview).
		SupportsVision: func(model string) bool {
			return strings.Contains(model, "vision")
		},
		// The chat completions endpoint takes no audio input; Groq handles
		// audio via dedicated transcription endpoints, which go-llm doesn't
		// cover here.
		SupportsAudio: func(string) bool { return false },
	}
	return openaicompat.New(apiKey, endpoint, rules)
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package groq_test

import (
	"context"
	"errors"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// TestNew_Basic checks the constructor returns a usable provider with the
// default base URL.
func TestNew_Basic(t *testing.T) {
	prov := groq.New("key", "")
	if prov == nil {
		t.Fatal("New returned nil")
	}
}

// TestRules_AudioRejected checks that audio attachments are rejected
// client-side (before any network call) with a typed error.
func TestRules_AudioRejected(t *testing.T) {
	prov := groq.New("key", "")
	msg := provider.Message{
		Role:  "user",
		Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
	}
	req := provider.Request{
		Model:    "llama-3.3-70b-versatile",
		Messages: []provider.Message{msg},
	}
	_, err := prov.Complete(context.Background(), req)
	var unsupported *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &unsupported) || unsupported.Feature != "audio" {
		t.Fatalf("want FeatureUnsupportedError(audio), got %v", err)
	}
}
|
||||
@@ -0,0 +1,30 @@
|
||||
// Package moonshot implements the go-llm v2 provider interface for Moonshot
// AI (Kimi, https://platform.moonshot.ai). Moonshot exposes an OpenAI Chat
// Completions-compatible API, so this package simply layers its own endpoint
// default and per-model Rules on top of openaicompat.
package moonshot

import (
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)

// DefaultBaseURL is the public Moonshot API endpoint (international).
const DefaultBaseURL = "https://api.moonshot.ai/v1"

// Provider is a type alias over openaicompat.Provider.
type Provider = openaicompat.Provider

// New creates a new Moonshot provider. An empty baseURL uses DefaultBaseURL.
func New(apiKey, baseURL string) *Provider {
	endpoint := baseURL
	if endpoint == "" {
		endpoint = DefaultBaseURL
	}
	rules := openaicompat.Rules{
		// Images are only accepted by Moonshot models whose name contains
		// "vision".
		SupportsVision: func(model string) bool {
			return strings.Contains(model, "vision")
		},
	}
	return openaicompat.New(apiKey, endpoint, rules)
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package moonshot_test

import (
	"context"
	"errors"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// TestNew_Basic checks the constructor returns a usable provider with the
// default base URL.
func TestNew_Basic(t *testing.T) {
	prov := moonshot.New("key", "")
	if prov == nil {
		t.Fatal("New returned nil")
	}
}

// TestRules_NonVisionModelRejectsImages checks that image attachments on a
// text-only model fail client-side with a typed error.
func TestRules_NonVisionModelRejectsImages(t *testing.T) {
	prov := moonshot.New("key", "")
	msg := provider.Message{
		Role:   "user",
		Images: []provider.Image{{URL: "a"}},
	}
	req := provider.Request{
		Model:    "moonshot-v1-8k",
		Messages: []provider.Message{msg},
	}
	_, err := prov.Complete(context.Background(), req)
	var unsupported *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &unsupported) || unsupported.Feature != "vision" {
		t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
	}
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// Package ollama implements the go-llm v2 provider interface for Ollama
|
||||
// (https://ollama.com), a local model runner that exposes an OpenAI Chat
|
||||
// Completions-compatible endpoint. No API key is required; capability depends
|
||||
// on whichever model the user has pulled locally, so Rules are intentionally
|
||||
// empty — we trust the local user.
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
)
|
||||
|
||||
// DefaultBaseURL points at a local Ollama instance with default port.
|
||||
const DefaultBaseURL = "http://localhost:11434/v1"
|
||||
|
||||
// Provider is a type alias over openaicompat.Provider.
|
||||
type Provider = openaicompat.Provider
|
||||
|
||||
// New creates a new Ollama provider. An empty baseURL uses DefaultBaseURL.
|
||||
// Ollama ignores the API key; callers may pass "".
|
||||
func New(apiKey, baseURL string) *Provider {
|
||||
if baseURL == "" {
|
||||
baseURL = DefaultBaseURL
|
||||
}
|
||||
return openaicompat.New(apiKey, baseURL, openaicompat.Rules{})
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package ollama_test

import (
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
)

// TestNew_NoKeyNeeded checks that construction succeeds with both the API key
// and base URL left empty (Ollama needs neither).
func TestNew_NoKeyNeeded(t *testing.T) {
	prov := ollama.New("", "")
	if prov == nil {
		t.Fatal("New returned nil")
	}
}
|
||||
+22
-420
@@ -1,433 +1,35 @@
|
||||
// Package openai implements the go-llm v2 provider interface for OpenAI.
|
||||
//
|
||||
// The actual wire-protocol logic lives in the shared openaicompat package;
|
||||
// this file encodes OpenAI-specific Rules (temperature is rejected on o-series
|
||||
// and gpt-5* models) and supplies the default base URL.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/shared"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
)
|
||||
|
||||
// Provider implements the provider.Provider interface for OpenAI.
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
// DefaultBaseURL is the public OpenAI Chat Completions endpoint.
|
||||
const DefaultBaseURL = "https://api.openai.com/v1"
|
||||
|
||||
// New creates a new OpenAI provider.
|
||||
// Provider is the OpenAI chat-completion provider. It's a type alias over
|
||||
// openaicompat.Provider so existing callers using openai.Provider keep compiling.
|
||||
type Provider = openaicompat.Provider
|
||||
|
||||
// New creates a new OpenAI provider. An empty baseURL uses DefaultBaseURL.
|
||||
func New(apiKey string, baseURL string) *Provider {
|
||||
return &Provider{apiKey: apiKey, baseURL: baseURL}
|
||||
if baseURL == "" {
|
||||
baseURL = DefaultBaseURL
|
||||
}
|
||||
return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
|
||||
RestrictTemperature: restrictTemperature,
|
||||
})
|
||||
}
|
||||
|
||||
// Complete performs a non-streaming completion.
//
// A fresh SDK client is assembled per call from the provider's API key and
// optional base-URL override; req is translated to the OpenAI request shape,
// and the SDK response is converted back into the provider-neutral Response.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	var opts []option.RequestOption
	opts = append(opts, option.WithAPIKey(p.apiKey))
	if p.baseURL != "" {
		// Only override the SDK's default endpoint when one was configured.
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}

	cl := openai.NewClient(opts...)
	oaiReq := p.buildRequest(req)

	resp, err := cl.Chat.Completions.New(ctx, oaiReq)
	if err != nil {
		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
	}

	return p.convertResponse(resp), nil
}
|
||||
|
||||
// Stream performs a streaming completion.
//
// Text and tool-call deltas are forwarded on events as they arrive; after the
// stream drains, one StreamEventToolEnd per tool call and a final
// StreamEventDone carrying the assembled Response (full text, finalized tool
// calls, usage) are sent. The events channel is not closed here; that is the
// caller's responsibility.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	var opts []option.RequestOption
	opts = append(opts, option.WithAPIKey(p.apiKey))
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}

	cl := openai.NewClient(opts...)
	oaiReq := p.buildRequest(req)
	// Ask the API to append a usage chunk at the end of the stream.
	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)

	// Accumulators for the final Response emitted with StreamEventDone.
	var fullText strings.Builder
	var toolCalls []provider.ToolCall
	// Per-tool-call argument fragments, keyed by the API's tool-call index.
	toolCallArgs := map[int]*strings.Builder{}
	var usage *provider.Usage

	for stream.Next() {
		chunk := stream.Current()

		// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
		if chunk.Usage.TotalTokens > 0 {
			usage = &provider.Usage{
				InputTokens:  int(chunk.Usage.PromptTokens),
				OutputTokens: int(chunk.Usage.CompletionTokens),
				TotalTokens:  int(chunk.Usage.TotalTokens),
				Details:      extractUsageDetails(chunk.Usage),
			}
		}

		for _, choice := range chunk.Choices {
			// Text delta
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				events <- provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}
			}

			// Tool call deltas
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)

				if tc.ID != "" {
					// New tool call starting: grow the slice so idx is valid,
					// then record its identity and start an args accumulator.
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}

					events <- provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}
				}

				if tc.Function.Arguments != "" {
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}
				}
			}
		}
	}

	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}

	// Finalize tool calls: join the accumulated argument fragments and emit
	// one ToolEnd event per call.
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}
	}

	// Send done event
	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
			Usage:     usage,
		},
	}

	return nil
}
|
||||
|
||||
// buildRequest translates a provider-neutral Request into the OpenAI SDK's
// ChatCompletionNewParams: messages, tool definitions, and the optional
// sampling knobs (temperature, max tokens, top-p, stop).
func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
	oaiReq := openai.ChatCompletionNewParams{
		Model: req.Model,
	}

	for _, msg := range req.Messages {
		oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
	}

	for _, tool := range req.Tools {
		oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
			Type: "function",
			Function: shared.FunctionDefinitionParam{
				Name:        tool.Name,
				Description: openai.String(tool.Description),
				Parameters:  openai.FunctionParameters(tool.Schema),
			},
		})
	}

	if req.Temperature != nil {
		// o* and gpt-5* models don't support custom temperatures
		if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") {
			oaiReq.Temperature = openai.Float(*req.Temperature)
		}
	}

	if req.MaxTokens != nil {
		oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
	}

	if req.TopP != nil {
		oaiReq.TopP = openai.Float(*req.TopP)
	}

	if len(req.Stop) > 0 {
		// Only the first stop sequence is forwarded.
		oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
	}

	return oaiReq
}
|
||||
|
||||
// convertMessage translates one provider-neutral Message into the SDK's
// message union, choosing the variant by role. Image/audio attachments are
// collected into content parts; when any parts exist, the text content rides
// along as a text part instead of a plain string.
func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
	var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
	var textContent param.Opt[string]

	// Images: inline base64 becomes a data: URL; otherwise the URL is passed
	// through. Images with neither are silently dropped.
	for _, img := range msg.Images {
		var url string
		if img.Base64 != "" {
			url = "data:" + img.ContentType + ";base64," + img.Base64
		} else if img.URL != "" {
			url = img.URL
		}
		if url != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: url,
						},
					},
				},
			)
		}
	}

	// Audio: the API takes base64 + a format token, so URL-based audio is
	// fetched and re-encoded here. Fetch/read failures skip the attachment.
	for _, aud := range msg.Audio {
		var b64Data string
		var format string

		if aud.Base64 != "" {
			b64Data = aud.Base64
			format = audioFormat(aud.ContentType)
		} else if aud.URL != "" {
			// NOTE(review): http.Get uses the default client with no timeout
			// or context — a slow host can stall this indefinitely; consider
			// a ctx-aware request.
			resp, err := http.Get(aud.URL)
			if err != nil {
				continue
			}
			data, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				continue
			}
			b64Data = base64.StdEncoding.EncodeToString(data)
			// Content type preference: response header, then the caller's
			// hint, then a guess from the URL extension.
			ct := resp.Header.Get("Content-Type")
			if ct == "" {
				ct = aud.ContentType
			}
			if ct == "" {
				ct = audioFormatFromURL(aud.URL)
			}
			format = audioFormat(ct)
		}

		if b64Data != "" && format != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
						InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
							Data:   b64Data,
							Format: format,
						},
					},
				},
			)
		}
	}

	// Text: becomes a content part when attachments exist, a plain string
	// otherwise.
	if msg.Content != "" {
		if len(arrayOfContentParts) > 0 {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: msg.Content,
					},
				},
			)
		} else {
			textContent = openai.String(msg.Content)
		}
	}

	// Determine if this model uses developer messages instead of system.
	// Heuristic: a first dash-separated segment starting with 'o' (e.g.
	// "o1-mini") — NOTE(review): this also matches non-OpenAI names like
	// "omni-foo"; verify against the supported model list.
	useDeveloper := false
	parts := strings.Split(model, "-")
	if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
		useDeveloper = true
	}

	switch msg.Role {
	case "system":
		if useDeveloper {
			return openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfSystem: &openai.ChatCompletionSystemMessageParam{
				Content: openai.ChatCompletionSystemMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}

	case "user":
		return openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{
				Content: openai.ChatCompletionUserMessageParamContentUnion{
					OfString:              textContent,
					OfArrayOfContentParts: arrayOfContentParts,
				},
			},
		}

	case "assistant":
		as := &openai.ChatCompletionAssistantMessageParam{}
		if msg.Content != "" {
			as.Content.OfString = openai.String(msg.Content)
		}
		for _, tc := range msg.ToolCalls {
			as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
				ID: tc.ID,
				Function: openai.ChatCompletionMessageToolCallFunctionParam{
					Name:      tc.Name,
					Arguments: tc.Arguments,
				},
			})
		}
		return openai.ChatCompletionMessageParamUnion{OfAssistant: as}

	case "tool":
		return openai.ChatCompletionMessageParamUnion{
			OfTool: &openai.ChatCompletionToolMessageParam{
				ToolCallID: msg.ToolCallID,
				Content: openai.ChatCompletionToolMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				},
			},
		}
	}

	// Fallback to user message
	return openai.ChatCompletionMessageParamUnion{
		OfUser: &openai.ChatCompletionUserMessageParam{
			Content: openai.ChatCompletionUserMessageParamContentUnion{
				OfString: textContent,
			},
		},
	}
}
|
||||
|
||||
// convertResponse maps the SDK's ChatCompletion into the provider-neutral
// Response. Only the first choice is used; a nil or choiceless completion
// yields the zero Response.
func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
	var res provider.Response

	if resp == nil || len(resp.Choices) == 0 {
		return res
	}

	choice := resp.Choices[0]
	res.Text = choice.Message.Content

	for _, tc := range choice.Message.ToolCalls {
		res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
			ID:   tc.ID,
			Name: tc.Function.Name,
			// Arguments are trimmed; some responses pad the JSON with
			// surrounding whitespace.
			Arguments: strings.TrimSpace(tc.Function.Arguments),
		})
	}

	// Usage is only attached when the API reported a non-zero total.
	if resp.Usage.TotalTokens > 0 {
		res.Usage = &provider.Usage{
			InputTokens:  int(resp.Usage.PromptTokens),
			OutputTokens: int(resp.Usage.CompletionTokens),
			TotalTokens:  int(resp.Usage.TotalTokens),
		}
		res.Usage.Details = extractUsageDetails(resp.Usage)
	}

	return res
}
|
||||
|
||||
// audioFormat maps a MIME content type onto the audio format token the
// OpenAI API expects ("wav" or "mp3"). Anything unrecognized falls back to
// "wav".
func audioFormat(contentType string) string {
	lowered := strings.ToLower(contentType)
	if strings.Contains(lowered, "wav") {
		return "wav"
	}
	if strings.Contains(lowered, "mp3") || strings.Contains(lowered, "mpeg") {
		return "mp3"
	}
	return "wav"
}
|
||||
|
||||
// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
|
||||
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
|
||||
details := map[string]int{}
|
||||
if usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
|
||||
}
|
||||
if usage.CompletionTokensDetails.AudioTokens > 0 {
|
||||
details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
|
||||
}
|
||||
if usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
if usage.PromptTokensDetails.AudioTokens > 0 {
|
||||
details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return nil
|
||||
}
|
||||
return details
|
||||
}
|
||||
|
||||
// audioFormatFromURL guesses an audio MIME type from a URL's file extension,
// defaulting to "audio/wav" for anything unrecognized. (As shown in the diff
// the function's closing brace was lost; restored here.)
func audioFormatFromURL(u string) string {
	switch strings.ToLower(path.Ext(u)) {
	case ".mp3":
		return "audio/mp3"
	case ".wav":
		return "audio/wav"
	default:
		return "audio/wav"
	}
}

// restrictTemperature reports whether OpenAI rejects a user-supplied
// temperature for this model. o-series reasoning models and gpt-5* both do.
func restrictTemperature(model string) bool {
	return strings.HasPrefix(model, "o") || strings.HasPrefix(model, "gpt-5")
}
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
// Package openaicompat implements a shared chat-completion Provider for any
|
||||
// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek,
|
||||
// Moonshot, xAI, Groq, Ollama, and friends).
|
||||
//
|
||||
// Most providers differ from vanilla OpenAI only in endpoint URL and a handful
|
||||
// of per-model quirks (e.g., "this model is text-only", "this model doesn't
|
||||
// accept tools", "drop temperature on reasoning models"). Those quirks are
|
||||
// captured declaratively via Rules, so a concrete provider package is usually
|
||||
// a one-function wrapper that calls New with its own base URL and Rules.
|
||||
package openaicompat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/shared"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
// Rules encodes provider-specific constraints on top of the OpenAI wire
// protocol. The zero value means "no restrictions" and behaves like vanilla
// OpenAI. All checks run client-side, before any network call. Individual
// fields are documented inline.
type Rules struct {
	// MaxImagesPerMessage rejects requests whose any single message carries
	// more images than this cap. 0 means "no cap".
	MaxImagesPerMessage int

	// MaxAudioPerMessage rejects requests whose any single message carries
	// more audio attachments than this cap. 0 means "no cap".
	MaxAudioPerMessage int

	// SupportsVision, when non-nil, is consulted for every request that
	// includes any image attachments. If it returns false for the request's
	// model, the call fails with a FeatureUnsupportedError before hitting
	// the network.
	SupportsVision func(model string) bool

	// SupportsTools, when non-nil, is consulted for every request that
	// includes any tool definitions. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError before hitting the
	// network.
	SupportsTools func(model string) bool

	// SupportsAudio, when non-nil, is consulted for every request that
	// includes any audio attachments. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError.
	SupportsAudio func(model string) bool

	// RestrictTemperature, when non-nil and returning true for the request's
	// model, causes the Temperature field to be silently dropped from the
	// outgoing request (the request itself still proceeds). Used by OpenAI
	// o-series and gpt-5* which reject a user-provided temperature.
	RestrictTemperature func(model string) bool

	// CustomizeRequest is a last-mile hook invoked after buildRequest but
	// before the call is sent. It receives the fully built OpenAI SDK
	// parameters and may mutate them freely (add headers, flip flags, tweak
	// response_format, etc.).
	CustomizeRequest func(params *openai.ChatCompletionNewParams)
}
|
||||
|
||||
// FeatureUnsupportedError is returned when a Rules predicate rejects a
// request because the target model does not support a feature the caller
// included ("vision", "audio", or "tools").
type FeatureUnsupportedError struct {
	Feature string
	Model   string
}

// Error formats the rejection as a single human-readable line.
func (fe *FeatureUnsupportedError) Error() string {
	return fmt.Sprintf("openaicompat: model %q does not support %s", fe.Model, fe.Feature)
}
|
||||
|
||||
// Provider implements provider.Provider for any OpenAI-compatible endpoint.
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
rules Rules
|
||||
}
|
||||
|
||||
// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its
|
||||
// default; in practice concrete provider packages always pass a default.
|
||||
func New(apiKey, baseURL string, rules Rules) *Provider {
|
||||
return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules}
|
||||
}
|
||||
|
||||
// Complete performs a non-streaming completion.
|
||||
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
if err := p.checkRules(req); err != nil {
|
||||
return provider.Response{}, err
|
||||
}
|
||||
|
||||
cl := openai.NewClient(p.requestOptions()...)
|
||||
oaiReq := p.buildRequest(req)
|
||||
if p.rules.CustomizeRequest != nil {
|
||||
p.rules.CustomizeRequest(&oaiReq)
|
||||
}
|
||||
|
||||
resp, err := cl.Chat.Completions.New(ctx, oaiReq)
|
||||
if err != nil {
|
||||
return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
|
||||
}
|
||||
|
||||
return p.convertResponse(resp), nil
|
||||
}
|
||||
|
||||
// Stream performs a streaming completion.
//
// Rules are validated first. Text and tool-call deltas are forwarded on
// events as they arrive; after the stream drains, one StreamEventToolEnd per
// tool call and a final StreamEventDone carrying the assembled Response (full
// text, finalized tool calls, usage) are sent. The events channel is not
// closed here; that is the caller's responsibility.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	if err := p.checkRules(req); err != nil {
		return err
	}

	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	// Ask the API to append a usage chunk at the end of the stream.
	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)

	// Accumulators for the final Response emitted with StreamEventDone.
	var fullText strings.Builder
	var toolCalls []provider.ToolCall
	// Per-tool-call argument fragments, keyed by the API's tool-call index.
	toolCallArgs := map[int]*strings.Builder{}
	var usage *provider.Usage

	for stream.Next() {
		chunk := stream.Current()

		// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
		if chunk.Usage.TotalTokens > 0 {
			usage = &provider.Usage{
				InputTokens:  int(chunk.Usage.PromptTokens),
				OutputTokens: int(chunk.Usage.CompletionTokens),
				TotalTokens:  int(chunk.Usage.TotalTokens),
				Details:      extractUsageDetails(chunk.Usage),
			}
		}

		for _, choice := range chunk.Choices {
			// Text delta
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				events <- provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}
			}

			// Tool call deltas
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)

				if tc.ID != "" {
					// New tool call starting: grow the slice so idx is valid,
					// then record its identity and start an args accumulator.
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}

					events <- provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}
				}

				if tc.Function.Arguments != "" {
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}
				}
			}
		}
	}

	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}

	// Finalize tool calls: join the accumulated argument fragments and emit
	// one ToolEnd event per call.
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}
	}

	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
			Usage:     usage,
		},
	}

	return nil
}
|
||||
|
||||
func (p *Provider) requestOptions() []option.RequestOption {
|
||||
opts := []option.RequestOption{option.WithAPIKey(p.apiKey)}
|
||||
if p.baseURL != "" {
|
||||
opts = append(opts, option.WithBaseURL(p.baseURL))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// checkRules applies all Rules predicates against a request and returns an
|
||||
// error if any constraint is violated. Runs before any network call.
|
||||
func (p *Provider) checkRules(req provider.Request) error {
|
||||
var hasImages, hasAudio bool
|
||||
for _, msg := range req.Messages {
|
||||
if len(msg.Images) > 0 {
|
||||
hasImages = true
|
||||
}
|
||||
if len(msg.Audio) > 0 {
|
||||
hasAudio = true
|
||||
}
|
||||
if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage {
|
||||
return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q",
|
||||
len(msg.Images), p.rules.MaxImagesPerMessage, req.Model)
|
||||
}
|
||||
if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage {
|
||||
return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q",
|
||||
len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) {
|
||||
return &FeatureUnsupportedError{Feature: "vision", Model: req.Model}
|
||||
}
|
||||
if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) {
|
||||
return &FeatureUnsupportedError{Feature: "audio", Model: req.Model}
|
||||
}
|
||||
if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) {
|
||||
return &FeatureUnsupportedError{Feature: "tools", Model: req.Model}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
|
||||
oaiReq := openai.ChatCompletionNewParams{
|
||||
Model: req.Model,
|
||||
}
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
|
||||
Type: "function",
|
||||
Function: shared.FunctionDefinitionParam{
|
||||
Name: tool.Name,
|
||||
Description: openai.String(tool.Description),
|
||||
Parameters: openai.FunctionParameters(tool.Schema),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if req.Temperature != nil {
|
||||
if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) {
|
||||
oaiReq.Temperature = openai.Float(*req.Temperature)
|
||||
}
|
||||
}
|
||||
|
||||
if req.MaxTokens != nil {
|
||||
oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
|
||||
}
|
||||
|
||||
if req.TopP != nil {
|
||||
oaiReq.TopP = openai.Float(*req.TopP)
|
||||
}
|
||||
|
||||
if len(req.Stop) > 0 {
|
||||
oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
|
||||
}
|
||||
|
||||
return oaiReq
|
||||
}
|
||||
|
||||
// convertMessage translates a provider.Message into the OpenAI SDK's message
// union type. Image and audio attachments become multi-part content; plain
// text stays a simple string unless parts exist, in which case the text rides
// along as one more part. model is consulted only to decide whether a
// "system" message should be sent with the "developer" role instead.
func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
	var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
	var textContent param.Opt[string]

	// Images: inline base64 becomes a data: URL; otherwise the remote URL is
	// passed through as-is. Images with neither field set are silently dropped.
	for _, img := range msg.Images {
		var url string
		if img.Base64 != "" {
			url = "data:" + img.ContentType + ";base64," + img.Base64
		} else if img.URL != "" {
			url = img.URL
		}
		if url != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: url,
						},
					},
				},
			)
		}
	}

	// Audio: the request format used here carries inline base64 audio only,
	// so remote URLs are downloaded and re-encoded. Fetch/read failures skip
	// the attachment silently — best effort by design.
	// NOTE(review): http.Get uses the default client with no timeout and no
	// caller context; a slow host can stall this conversion — confirm this
	// is acceptable.
	for _, aud := range msg.Audio {
		var b64Data string
		var format string

		if aud.Base64 != "" {
			b64Data = aud.Base64
			format = audioFormat(aud.ContentType)
		} else if aud.URL != "" {
			resp, err := http.Get(aud.URL)
			if err != nil {
				continue
			}
			data, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				continue
			}
			b64Data = base64.StdEncoding.EncodeToString(data)
			// Content-type resolution order: HTTP response header, then the
			// attachment's declared type, then a guess from the URL extension.
			ct := resp.Header.Get("Content-Type")
			if ct == "" {
				ct = aud.ContentType
			}
			if ct == "" {
				ct = audioFormatFromURL(aud.URL)
			}
			format = audioFormat(ct)
		}

		if b64Data != "" && format != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
						InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
							Data:   b64Data,
							Format: format,
						},
					},
				},
			)
		}
	}

	// Text: when attachment parts were built above, the text is appended as
	// another part; otherwise it is sent as a plain string.
	if msg.Content != "" {
		if len(arrayOfContentParts) > 0 {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: msg.Content,
					},
				},
			)
		} else {
			textContent = openai.String(msg.Content)
		}
	}

	// Determine if this model uses developer messages instead of system.
	// Heuristic: hyphenated names whose first segment starts with 'o'
	// (o1-mini, o1-preview, o3-mini, ...).
	// NOTE(review): a bare "o1" (no hyphen) does not match this check and
	// will be sent a "system" message — confirm that is intended.
	useDeveloper := false
	parts := strings.Split(model, "-")
	if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
		useDeveloper = true
	}

	switch msg.Role {
	case "system":
		// System/developer turns carry text only; any image/audio parts
		// accumulated above are intentionally not attached here.
		if useDeveloper {
			return openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfSystem: &openai.ChatCompletionSystemMessageParam{
				Content: openai.ChatCompletionSystemMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}

	case "user":
		return openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{
				Content: openai.ChatCompletionUserMessageParamContentUnion{
					OfString:              textContent,
					OfArrayOfContentParts: arrayOfContentParts,
				},
			},
		}

	case "assistant":
		// Assistant turns replay prior model output: optional text plus any
		// tool calls the model made (needed so "tool" results can follow).
		as := &openai.ChatCompletionAssistantMessageParam{}
		if msg.Content != "" {
			as.Content.OfString = openai.String(msg.Content)
		}
		for _, tc := range msg.ToolCalls {
			as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
				ID: tc.ID,
				Function: openai.ChatCompletionMessageToolCallFunctionParam{
					Name:      tc.Name,
					Arguments: tc.Arguments,
				},
			})
		}
		return openai.ChatCompletionMessageParamUnion{OfAssistant: as}

	case "tool":
		return openai.ChatCompletionMessageParamUnion{
			OfTool: &openai.ChatCompletionToolMessageParam{
				ToolCallID: msg.ToolCallID,
				Content: openai.ChatCompletionToolMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				},
			},
		}
	}

	// Fallback to user message for unrecognized roles.
	return openai.ChatCompletionMessageParamUnion{
		OfUser: &openai.ChatCompletionUserMessageParam{
			Content: openai.ChatCompletionUserMessageParamContentUnion{
				OfString: textContent,
			},
		},
	}
}
|
||||
|
||||
func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
|
||||
var res provider.Response
|
||||
|
||||
if resp == nil || len(resp.Choices) == 0 {
|
||||
return res
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
res.Text = choice.Message.Content
|
||||
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: strings.TrimSpace(tc.Function.Arguments),
|
||||
})
|
||||
}
|
||||
|
||||
if resp.Usage.TotalTokens > 0 {
|
||||
res.Usage = &provider.Usage{
|
||||
InputTokens: int(resp.Usage.PromptTokens),
|
||||
OutputTokens: int(resp.Usage.CompletionTokens),
|
||||
TotalTokens: int(resp.Usage.TotalTokens),
|
||||
}
|
||||
res.Usage.Details = extractUsageDetails(resp.Usage)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// audioFormat converts a MIME type to an OpenAI audio format string. Only
// "wav" and "mp3" are recognized; anything else (including an empty type)
// defaults to "wav".
func audioFormat(contentType string) string {
	ct := strings.ToLower(contentType)
	if strings.Contains(ct, "wav") {
		return "wav"
	}
	if strings.Contains(ct, "mp3") || strings.Contains(ct, "mpeg") {
		return "mp3"
	}
	return "wav"
}
|
||||
|
||||
// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
|
||||
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
|
||||
details := map[string]int{}
|
||||
if usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
|
||||
}
|
||||
if usage.CompletionTokensDetails.AudioTokens > 0 {
|
||||
details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
|
||||
}
|
||||
if usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
if usage.PromptTokensDetails.AudioTokens > 0 {
|
||||
details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return nil
|
||||
}
|
||||
return details
|
||||
}
|
||||
|
||||
// audioFormatFromURL guesses an audio MIME content type from a URL's file
// extension. Query strings and fragments are stripped first so signed URLs
// such as "https://host/a.mp3?sig=..." still resolve; previously path.Ext
// returned ".mp3?sig=..." for those and the match silently failed. Unknown
// extensions fall back to "audio/wav" (the same default audioFormat uses).
func audioFormatFromURL(u string) string {
	// Drop any query string or fragment before taking the extension.
	if i := strings.IndexAny(u, "?#"); i >= 0 {
		u = u[:i]
	}
	switch strings.ToLower(path.Ext(u)) {
	case ".mp3":
		return "audio/mp3"
	case ".wav":
		return "audio/wav"
	default:
		return "audio/wav"
	}
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package openaicompat_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
// newTestServer returns a httptest server that captures the raw request body
|
||||
// on POST /chat/completions and returns a canned OpenAI response so Complete()
|
||||
// succeeds. Use `captured` to assert on what the provider would send.
|
||||
func newTestServer(t *testing.T) (*httptest.Server, *[]byte) {
|
||||
t.Helper()
|
||||
var body []byte
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
b, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
}
|
||||
body = b
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{
|
||||
"id": "cmpl-1",
|
||||
"object": "chat.completion",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role":"assistant","content":"ok"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
|
||||
}`)
|
||||
}))
|
||||
return srv, &body
|
||||
}
|
||||
|
||||
func textReq(model, content string) provider.Request {
|
||||
return provider.Request{
|
||||
Model: model,
|
||||
Messages: []provider.Message{{Role: "user", Content: content}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_ZeroRulesPassesThrough(t *testing.T) {
|
||||
srv, body := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
temp := 0.7
|
||||
req := textReq("gpt-4o", "hi")
|
||||
req.Temperature = &temp
|
||||
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
||||
resp, err := p.Complete(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
if resp.Text != "ok" {
|
||||
t.Errorf("Text = %q, want %q", resp.Text, "ok")
|
||||
}
|
||||
|
||||
// Temperature should be present since RestrictTemperature is nil.
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(*body, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal request body: %v", err)
|
||||
}
|
||||
if _, ok := parsed["temperature"]; !ok {
|
||||
t.Errorf("expected temperature in request body, got: %s", *body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_RestrictTemperatureDropsField(t *testing.T) {
|
||||
srv, body := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
temp := 0.7
|
||||
req := textReq("o1", "hi")
|
||||
req.Temperature = &temp
|
||||
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
||||
RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") },
|
||||
})
|
||||
if _, err := p.Complete(context.Background(), req); err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(*body, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if _, ok := parsed["temperature"]; ok {
|
||||
t.Errorf("temperature should be dropped for o1, got: %s", *body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) {
|
||||
srv, _ := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
req := provider.Request{
|
||||
Model: "deepseek-chat",
|
||||
Messages: []provider.Message{{
|
||||
Role: "user",
|
||||
Content: "describe",
|
||||
Images: []provider.Image{{URL: "https://example.com/a.png"}},
|
||||
}},
|
||||
}
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
||||
SupportsVision: func(string) bool { return false },
|
||||
})
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) {
|
||||
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
||||
}
|
||||
if fue.Feature != "vision" || fue.Model != "deepseek-chat" {
|
||||
t.Errorf("unexpected err: %+v", fue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) {
|
||||
srv, _ := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
req := provider.Request{
|
||||
Model: "deepseek-reasoner",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
Tools: []provider.ToolDef{
|
||||
{Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}},
|
||||
},
|
||||
}
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
||||
SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") },
|
||||
})
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) {
|
||||
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
||||
}
|
||||
if fue.Feature != "tools" {
|
||||
t.Errorf("feature = %q, want tools", fue.Feature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) {
|
||||
srv, _ := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
req := provider.Request{
|
||||
Model: "groq-llama",
|
||||
Messages: []provider.Message{{
|
||||
Role: "user",
|
||||
Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
|
||||
}},
|
||||
}
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
||||
SupportsAudio: func(string) bool { return false },
|
||||
})
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) {
|
||||
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
||||
}
|
||||
if fue.Feature != "audio" {
|
||||
t.Errorf("feature = %q, want audio", fue.Feature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_MaxImagesPerMessage(t *testing.T) {
|
||||
srv, _ := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
req := provider.Request{
|
||||
Model: "anything",
|
||||
Messages: []provider.Message{{
|
||||
Role: "user",
|
||||
Images: []provider.Image{
|
||||
{URL: "a"}, {URL: "b"}, {URL: "c"},
|
||||
},
|
||||
}},
|
||||
}
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2})
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
if err == nil || !strings.Contains(err.Error(), "max allowed is 2") {
|
||||
t.Fatalf("want max-images error, got %v", err)
|
||||
}
|
||||
|
||||
// Exactly at limit succeeds.
|
||||
req.Messages[0].Images = req.Messages[0].Images[:2]
|
||||
if _, err := p.Complete(context.Background(), req); err != nil {
|
||||
t.Errorf("at-limit request should succeed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_CustomizeRequestInvoked(t *testing.T) {
|
||||
srv, body := newTestServer(t)
|
||||
defer srv.Close()
|
||||
|
||||
called := false
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
||||
CustomizeRequest: func(params *openai.ChatCompletionNewParams) {
|
||||
called = true
|
||||
// Confirm we receive a non-empty built request.
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model)
|
||||
}
|
||||
// Mutation here should end up on the wire.
|
||||
params.User = openai.String("test-user")
|
||||
},
|
||||
})
|
||||
if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("CustomizeRequest hook was not invoked")
|
||||
}
|
||||
if !strings.Contains(string(*body), `"user":"test-user"`) {
|
||||
t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_EmitsDoneAndText(t *testing.T) {
|
||||
// SSE stream with one content chunk then [DONE].
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for _, line := range []string{
|
||||
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`,
|
||||
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`,
|
||||
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
|
||||
`data: [DONE]`,
|
||||
} {
|
||||
_, _ = io.WriteString(w, line+"\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
||||
events := make(chan provider.StreamEvent, 16)
|
||||
go func() {
|
||||
_ = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events)
|
||||
close(events)
|
||||
}()
|
||||
|
||||
var text strings.Builder
|
||||
var sawDone bool
|
||||
var doneUsage *provider.Usage
|
||||
for ev := range events {
|
||||
switch ev.Type {
|
||||
case provider.StreamEventText:
|
||||
text.WriteString(ev.Text)
|
||||
case provider.StreamEventDone:
|
||||
sawDone = true
|
||||
if ev.Response != nil {
|
||||
doneUsage = ev.Response.Usage
|
||||
}
|
||||
}
|
||||
}
|
||||
if text.String() != "hello" {
|
||||
t.Errorf("got text %q, want %q", text.String(), "hello")
|
||||
}
|
||||
if !sawDone {
|
||||
t.Fatal("no Done event emitted")
|
||||
}
|
||||
if doneUsage == nil || doneUsage.TotalTokens != 3 {
|
||||
t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_RulesCheckedBeforeNetwork(t *testing.T) {
|
||||
// Server should never be hit when rules reject up front.
|
||||
hit := false
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hit = true
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
||||
SupportsVision: func(string) bool { return false },
|
||||
})
|
||||
req := provider.Request{
|
||||
Model: "no-vision-model",
|
||||
Messages: []provider.Message{{
|
||||
Role: "user",
|
||||
Images: []provider.Image{{URL: "a"}},
|
||||
}},
|
||||
}
|
||||
events := make(chan provider.StreamEvent, 4)
|
||||
err := p.Stream(context.Background(), req, events)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) {
|
||||
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
||||
}
|
||||
if hit {
|
||||
t.Error("server was contacted despite Rules violation")
|
||||
}
|
||||
}
|
||||
+158
@@ -0,0 +1,158 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
|
||||
)
|
||||
|
||||
// ProviderInfo describes a registered provider for discovery purposes (CLI
// pickers, wiring layers, admin tools). It is the single source of truth for
// "what providers exist and how do I instantiate one." Obtain instances via
// Providers or ProviderByName rather than reading providerRegistry directly.
type ProviderInfo struct {
	// Name is the short lowercase identifier used in provider/model strings
	// (e.g., "openai", "deepseek", "moonshot").
	Name string

	// DisplayName is a human-readable label for UIs.
	DisplayName string

	// EnvKey is the conventional environment variable that holds the API key
	// for this provider. Empty string means "no key needed" (e.g., Ollama).
	EnvKey string

	// DefaultURL is the default base URL used when no override is supplied.
	DefaultURL string

	// Models is a list of well-known model names, populated for CLI pickers
	// and similar. It is not exhaustive and not validated against the API;
	// expect it to drift as vendors release new models.
	Models []string

	// New returns a ready-to-use Client for this provider, given an API key
	// (ignored for key-less providers like Ollama) and optional ClientOptions.
	New func(apiKey string, opts ...ClientOption) *Client
}
|
||||
|
||||
// providerRegistry is the in-process list of known providers. Order is
// intentional: the three original providers first, then OpenAI-compatible
// additions in the order they were added. Model lists are hand-maintained
// snapshots and are not validated against provider APIs.
// NOTE(review): verify the listed model identifiers against each vendor's
// current documentation before relying on them in pickers.
var providerRegistry = []ProviderInfo{
	{
		Name:        "openai",
		DisplayName: "OpenAI",
		EnvKey:      "OPENAI_API_KEY",
		DefaultURL:  openai.DefaultBaseURL,
		Models: []string{
			"gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano",
			"gpt-4o", "gpt-4o-mini",
			"gpt-4-turbo", "gpt-3.5-turbo",
			"o1", "o1-mini", "o1-preview", "o3-mini",
		},
		New: OpenAI,
	},
	{
		Name:        "anthropic",
		DisplayName: "Anthropic",
		EnvKey:      "ANTHROPIC_API_KEY",
		DefaultURL:  "https://api.anthropic.com",
		Models: []string{
			"claude-opus-4-7",
			"claude-sonnet-4-6",
			"claude-haiku-4-5-20251001",
			"claude-opus-4-20250514",
			"claude-sonnet-4-20250514",
			"claude-3-7-sonnet-20250219",
			"claude-3-5-sonnet-20241022",
			"claude-3-5-haiku-20241022",
		},
		New: Anthropic,
	},
	{
		Name:        "google",
		DisplayName: "Google",
		EnvKey:      "GOOGLE_API_KEY",
		DefaultURL:  "https://generativelanguage.googleapis.com",
		Models: []string{
			"gemini-2.0-flash", "gemini-2.0-flash-lite",
			"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
		},
		New: Google,
	},
	{
		Name:        "deepseek",
		DisplayName: "DeepSeek",
		EnvKey:      "DEEPSEEK_API_KEY",
		DefaultURL:  deepseek.DefaultBaseURL,
		Models:      []string{"deepseek-chat", "deepseek-reasoner"},
		New:         DeepSeek,
	},
	{
		Name:        "moonshot",
		DisplayName: "Moonshot (Kimi)",
		EnvKey:      "MOONSHOT_API_KEY",
		DefaultURL:  moonshot.DefaultBaseURL,
		Models: []string{
			"kimi-k2-0711-preview",
			"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
			"moonshot-v1-8k-vision-preview",
		},
		New: Moonshot,
	},
	{
		Name:        "xai",
		DisplayName: "xAI (Grok)",
		EnvKey:      "XAI_API_KEY",
		DefaultURL:  xai.DefaultBaseURL,
		Models: []string{
			"grok-2", "grok-2-mini", "grok-2-vision", "grok-beta",
		},
		New: XAI,
	},
	{
		Name:        "groq",
		DisplayName: "Groq",
		EnvKey:      "GROQ_API_KEY",
		DefaultURL:  groq.DefaultBaseURL,
		Models: []string{
			"llama-3.3-70b-versatile",
			"llama-3.1-8b-instant",
			"mixtral-8x7b-32768",
			"gemma2-9b-it",
			"llama-3.2-90b-vision-preview",
		},
		New: Groq,
	},
	{
		Name:        "ollama",
		DisplayName: "Ollama (local)",
		EnvKey:      "", // no key needed
		DefaultURL:  ollama.DefaultBaseURL,
		Models: []string{
			"llama3.2", "llama3.1", "qwen2.5", "mistral", "gemma2", "phi4",
		},
		// Ollama is key-less; the apiKey argument is ignored.
		New: func(_ string, opts ...ClientOption) *Client { return Ollama(opts...) },
	},
}
|
||||
|
||||
// Providers returns a copy of the registered provider list so callers cannot
|
||||
// mutate library state.
|
||||
func Providers() []ProviderInfo {
|
||||
out := make([]ProviderInfo, len(providerRegistry))
|
||||
copy(out, providerRegistry)
|
||||
return out
|
||||
}
|
||||
|
||||
// ProviderByName returns the registered ProviderInfo with the given name, or
|
||||
// nil if no such provider is registered. Name matching is exact.
|
||||
func ProviderByName(name string) *ProviderInfo {
|
||||
for i := range providerRegistry {
|
||||
if providerRegistry[i].Name == name {
|
||||
p := providerRegistry[i]
|
||||
return &p
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
// Package xai implements the go-llm v2 provider interface for xAI (Grok,
|
||||
// https://x.ai/api). xAI speaks OpenAI Chat Completions, so this package is a
|
||||
// thin wrapper over openaicompat with its own defaults and per-model Rules.
|
||||
package xai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
)
|
||||
|
||||
// DefaultBaseURL is the public xAI API endpoint, used by New when no base URL
// override is supplied.
const DefaultBaseURL = "https://api.x.ai/v1"

// Provider is a type alias over openaicompat.Provider, so xAI callers get the
// shared OpenAI-compatible implementation (and its error types) directly.
type Provider = openaicompat.Provider
|
||||
|
||||
// New creates a new xAI provider. An empty baseURL uses DefaultBaseURL.
|
||||
func New(apiKey, baseURL string) *Provider {
|
||||
if baseURL == "" {
|
||||
baseURL = DefaultBaseURL
|
||||
}
|
||||
return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
|
||||
// Grok models whose name contains "vision" accept images; others don't.
|
||||
SupportsVision: func(m string) bool {
|
||||
return strings.Contains(m, "vision")
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package xai_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
|
||||
)
|
||||
|
||||
func TestNew_Basic(t *testing.T) {
|
||||
if p := xai.New("key", ""); p == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRules_Grok2RejectsImages(t *testing.T) {
|
||||
p := xai.New("key", "")
|
||||
req := provider.Request{
|
||||
Model: "grok-2",
|
||||
Messages: []provider.Message{{
|
||||
Role: "user",
|
||||
Images: []provider.Image{{URL: "a"}},
|
||||
}},
|
||||
}
|
||||
_, err := p.Complete(context.Background(), req)
|
||||
var fue *openaicompat.FeatureUnsupportedError
|
||||
if !errors.As(err, &fue) || fue.Feature != "vision" {
|
||||
t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user