feat: add DeepSeek, Moonshot, xAI, Groq, Ollama; drop v1; migrate TUI to v2
Five OpenAI-compatible providers join the library as first-class constructors
(llm.DeepSeek, llm.Moonshot, llm.XAI, llm.Groq, llm.Ollama). Their wire-level
implementation is shared via a new v2/openaicompat package, which is the
extracted guts of the old v2/openai provider; each provider supplies its own
Rules value to declare per-model constraints (e.g., DeepSeek Reasoner rejects
tools and temperature, Moonshot/xAI accept images only on *-vision* models,
Groq rejects audio input). v2/openai itself becomes a thin wrapper that sets
RestrictTemperature for o-series and gpt-5 models.

A new provider registry (v2/registry.go) exposes llm.Providers() and drives
the TUI's provider picker, so adding a provider in the future is a
single-file change. The TUI at cmd/llm was migrated from v1 to v2 and moved
to v2/cmd/llm. With nothing else depending on v1, the v1 code at the repo
root (all .go files, schema/, internal/, provider/, root go.mod/go.sum) is
deleted.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
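The per-provider Rules idea described above can be sketched standalone. This is a hypothetical illustration — the type name `Rules`, its fields, and the `validate` helper are assumptions, not the real v2/openaicompat API; only the constraints themselves (DeepSeek Reasoner rejecting tools and temperature, Moonshot accepting images only on *-vision* models) come from the commit message.

```go
package main

import (
	"fmt"
	"strings"
)

// Rules is a hypothetical sketch of a per-provider constraint value;
// the real v2/openaicompat.Rules type may look quite different.
type Rules struct {
	RejectsTools       func(model string) bool
	RejectsTemperature func(model string) bool
	AcceptsImages      func(model string) bool
}

// deepSeekRules mirrors the constraint named in the commit message:
// DeepSeek Reasoner rejects tools and temperature.
var deepSeekRules = Rules{
	RejectsTools:       func(m string) bool { return strings.Contains(m, "reasoner") },
	RejectsTemperature: func(m string) bool { return strings.Contains(m, "reasoner") },
	AcceptsImages:      func(m string) bool { return false },
}

// moonshotRules: images are accepted only on *-vision* models.
var moonshotRules = Rules{
	RejectsTools:       func(m string) bool { return false },
	RejectsTemperature: func(m string) bool { return false },
	AcceptsImages:      func(m string) bool { return strings.Contains(m, "vision") },
}

// validate applies a provider's Rules to a request shape before sending it.
func validate(r Rules, model string, hasTools, hasTemp, hasImages bool) error {
	if hasTools && r.RejectsTools(model) {
		return fmt.Errorf("model %s does not accept tools", model)
	}
	if hasTemp && r.RejectsTemperature(model) {
		return fmt.Errorf("model %s does not accept temperature", model)
	}
	if hasImages && !r.AcceptsImages(model) {
		return fmt.Errorf("model %s does not accept images", model)
	}
	return nil
}

func main() {
	// A tools request against DeepSeek Reasoner fails the rules check.
	fmt.Println(validate(deepSeekRules, "deepseek-reasoner", true, false, false))
	// An image request against a Moonshot vision model passes.
	fmt.Println(validate(moonshotRules, "moonshot-v1-8k-vision-preview", false, false, true))
}
```

Declaring constraints as data on each provider, rather than scattering model-name checks through the shared wire code, is what lets openaicompat stay provider-agnostic.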
@@ -1,88 +1,31 @@
 # CLAUDE.md for go-llm
 
-## Build and Test Commands
-
-- Build project: `go build ./...`
-- Run all tests: `go test ./...`
-- Run specific test: `go test -v -run <TestName> ./...`
-- Tidy dependencies: `go mod tidy`
-
-## Code Style Guidelines
-
-- **Indentation**: Use standard Go tabs for indentation.
-- **Naming**:
-  - Use `camelCase` for internal/private variables and functions.
-  - Use `PascalCase` for exported types, functions, and struct fields.
-  - Interface names should be concise (e.g., `LLM`, `ChatCompletion`).
-- **Error Handling**:
-  - Always check and handle errors immediately.
-  - Wrap errors with context using `fmt.Errorf("%w: ...", err)`.
-  - Use the project's internal `Error` struct in `error.go` when differentiating between error types is needed.
-- **Project Structure**:
-  - `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`).
-  - Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`.
-  - Schema definitions for tool calling are in the `schema/` directory.
-  - `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers.
-- **Imports**: Organize imports into groups: standard library, then third-party libraries.
-- **Documentation**: Use standard Go doc comments for exported symbols.
-- **README.md**: The README.md file should always be kept up to date with any significant changes to the project.
-
-## CLI Tool
-
-- Build CLI: `go build ./cmd/llm`
-- Run CLI: `./llm` (or `llm.exe` on Windows)
-- Run without building: `go run ./cmd/llm`
-
-### CLI Features
-
-- Interactive TUI for testing all go-llm features
-- Support for OpenAI, Anthropic, and Google providers
-- Image input (file path, URL, or base64)
-- Tool/function calling with demo tools
-- Temperature control and settings
-
-### Key Bindings
-
-- `Enter` - Send message
-- `Ctrl+I` - Add image
-- `Ctrl+T` - Toggle tools panel
-- `Ctrl+P` - Change provider
-- `Ctrl+M` - Change model
-- `Ctrl+S` - Settings
-- `Ctrl+N` - New conversation
-- `Esc` - Exit/Cancel
-
-## MCP (Model Context Protocol) Support
-
-The library supports connecting to MCP servers to use their tools. MCP servers can be connected via:
-
-- **stdio**: Run a command as a subprocess
-- **sse**: Connect to an SSE endpoint
-- **http**: Connect to a streamable HTTP endpoint
-
-### Usage Example
-
-```go
-ctx := context.Background()
-
-// Create and connect to an MCP server
-server := &llm.MCPServer{
-    Name:    "my-server",
-    Command: "my-mcp-server",
-    Args:    []string{"--some-flag"},
-}
-if err := server.Connect(ctx); err != nil {
-    log.Fatal(err)
-}
-defer server.Close()
-
-// Add the server to a toolbox
-toolbox := llm.NewToolBox().WithMCPServer(server)
-
-// Use the toolbox in requests - MCP tools are automatically available
-req := llm.Request{
-    Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}},
-    Toolbox:  toolbox,
-}
-```
-
-### MCPServer Options
-
-- `Name`: Friendly name for logging
-- `Command`: Command to run (for stdio transport)
-- `Args`: Command arguments
-- `Env`: Additional environment variables
-- `URL`: Endpoint URL (for sse/http transport)
-- `Transport`: "stdio" (default), "sse", or "http"
+All Go code now lives under `v2/`. The module path is
+`gitea.stevedudenhoeffer.com/steve/go-llm/v2`. There is no module at the
+repository root anymore; the v1 code at the root was deleted after all
+consumers migrated to v2.
+
+See `v2/CLAUDE.md` for build/test commands and per-package guidance.
+
+## CLI
+
+The interactive TUI lives at `v2/cmd/llm`:
+
+```
+cd v2 && go run ./cmd/llm
+```
+
+It iterates `llm.Providers()` so every registered provider (OpenAI, Anthropic,
+Google, DeepSeek, Moonshot, xAI, Groq, Ollama) appears in the picker
+automatically. Status is derived from each provider's env var; Ollama shows as
+"(local)" because it needs no key.
+
+### Key bindings
+
+- `Enter` — Send message
+- `Ctrl+I` — Add image
+- `Ctrl+T` — Toggle tools panel
+- `Ctrl+P` — Change provider
+- `Ctrl+M` — Change model
+- `Ctrl+S` — Settings
+- `Ctrl+N` — New conversation
+- `Esc` — Exit/Cancel
@@ -1,225 +0,0 @@
-package llm
-
-import (
-    "context"
-    "encoding/base64"
-    "encoding/json"
-    "fmt"
-    "io"
-    "log"
-    "log/slog"
-    "net/http"
-
-    "gitea.stevedudenhoeffer.com/steve/go-llm/internal/imageutil"
-
-    anth "github.com/liushuangls/go-anthropic/v2"
-)
-
-type anthropicImpl struct {
-    key   string
-    model string
-}
-
-var _ LLM = anthropicImpl{}
-
-func (a anthropicImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
-    a.model = modelVersion
-
-    // TODO: model verification?
-    return a, nil
-}
-
-func deferClose(c io.Closer) {
-    err := c.Close()
-    if err != nil {
-        slog.Error("error closing", "error", err)
-    }
-}
-
-func (a anthropicImpl) requestToAnthropicRequest(req Request) anth.MessagesRequest {
-    res := anth.MessagesRequest{
-        Model:     anth.Model(a.model),
-        MaxTokens: 1000,
-    }
-
-    msgs := []anth.Message{}
-
-    // we gotta convert messages into anthropic messages, however
-    // anthropic does not have a "system" message type, so we need to
-    // append it to the res.System field instead
-
-    for _, msg := range req.Messages {
-        if msg.Role == RoleSystem {
-            if len(res.System) > 0 {
-                res.System += "\n"
-            }
-            res.System += msg.Text
-        } else {
-            role := anth.RoleUser
-
-            if msg.Role == RoleAssistant {
-                role = anth.RoleAssistant
-            }
-
-            m := anth.Message{
-                Role:    role,
-                Content: []anth.MessageContent{},
-            }
-
-            if msg.Text != "" {
-                m.Content = append(m.Content, anth.MessageContent{
-                    Type: anth.MessagesContentTypeText,
-                    Text: &msg.Text,
-                })
-            }
-
-            for _, img := range msg.Images {
-                // anthropic doesn't allow the assistant to send images, so we need to say it's from the user
-                if m.Role == anth.RoleAssistant {
-                    m.Role = anth.RoleUser
-                }
-
-                if img.Base64 != "" {
-                    // Anthropic models expect images to be < 5MiB in size
-                    raw, err := base64.StdEncoding.DecodeString(img.Base64)
-
-                    if err != nil {
-                        continue
-                    }
-
-                    // Check if image size exceeds 5MiB (5242880 bytes)
-                    if len(raw) >= 5242880 {
-
-                        compressed, mime, err := imageutil.CompressImage(img.Base64, 5*1024*1024)
-
-                        // just replace the image with the compressed one
-                        if err != nil {
-                            continue
-                        }
-
-                        img.Base64 = compressed
-                        img.ContentType = mime
-                    }
-
-                    m.Content = append(m.Content, anth.NewImageMessageContent(
-                        anth.NewMessageContentSource(
-                            anth.MessagesContentSourceTypeBase64,
-                            img.ContentType,
-                            img.Base64,
-                        )))
-                } else if img.Url != "" {
-
-                    // download the image
-                    cl, err := http.NewRequest(http.MethodGet, img.Url, nil)
-                    if err != nil {
-                        log.Println("failed to create request", err)
-                        continue
-                    }
-
-                    resp, err := http.DefaultClient.Do(cl)
-                    if err != nil {
-                        log.Println("failed to download image", err)
-                        continue
-                    }
-
-                    defer deferClose(resp.Body)
-
-                    img.ContentType = resp.Header.Get("Content-Type")
-
-                    // read the image
-                    b, err := io.ReadAll(resp.Body)
-                    if err != nil {
-                        log.Println("failed to read image", err)
-                        continue
-                    }
-
-                    // base64 encode the image
-                    img.Base64 = string(b)
-
-                    m.Content = append(m.Content, anth.NewImageMessageContent(
-                        anth.NewMessageContentSource(
-                            anth.MessagesContentSourceTypeBase64,
-                            img.ContentType,
-                            img.Base64,
-                        )))
-                }
-            }
-
-            // if this has the same role as the previous message, we can append it to the previous message
-            // as anthropic expects alternating assistant and user roles
-            if len(msgs) > 0 && msgs[len(msgs)-1].Role == role {
-                m2 := &msgs[len(msgs)-1]
-
-                m2.Content = append(m2.Content, m.Content...)
-            } else {
-                msgs = append(msgs, m)
-            }
-        }
-    }
-
-    for _, tool := range req.Toolbox.Functions() {
-        res.Tools = append(res.Tools, anth.ToolDefinition{
-            Name:        tool.Name,
-            Description: tool.Description,
-            InputSchema: tool.Parameters.AnthropicInputSchema(),
-        })
-    }
-
-    res.Messages = msgs
-
-    if req.Temperature != nil {
-        var f = float32(*req.Temperature)
-        res.Temperature = &f
-    }
-
-    log.Println("llm request to anthropic request", res)
-
-    return res
-}
-
-func (a anthropicImpl) responseToLLMResponse(in anth.MessagesResponse) Response {
-    choice := ResponseChoice{}
-    for _, msg := range in.Content {
-
-        switch msg.Type {
-        case anth.MessagesContentTypeText:
-            if msg.Text != nil {
-                choice.Content += *msg.Text
-            }
-
-        case anth.MessagesContentTypeToolUse:
-            if msg.MessageContentToolUse != nil {
-                b, e := json.Marshal(msg.MessageContentToolUse.Input)
-                if e != nil {
-                    log.Println("failed to marshal input", e)
-                } else {
-                    choice.Calls = append(choice.Calls, ToolCall{
-                        ID: msg.MessageContentToolUse.ID,
-                        FunctionCall: FunctionCall{
-                            Name:      msg.MessageContentToolUse.Name,
-                            Arguments: string(b),
-                        },
-                    })
-                }
-            }
-        }
-    }
-
-    log.Println("anthropic response to llm response", choice)
-
-    return Response{
-        Choices: []ResponseChoice{choice},
-    }
-}
-
-func (a anthropicImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
-    cl := anth.NewClient(a.key)
-
-    res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req))
-
-    if err != nil {
-        return Response{}, fmt.Errorf("failed to chat complete: %w", err)
-    }
-
-    return a.responseToLLMResponse(res), nil
-}
@@ -1,11 +0,0 @@
-# go-llm CLI Environment Variables
-# Copy this file to .env and fill in your API keys
-
-# OpenAI API Key (https://platform.openai.com/api-keys)
-OPENAI_API_KEY=
-
-# Anthropic API Key (https://console.anthropic.com/settings/keys)
-ANTHROPIC_API_KEY=
-
-# Google AI API Key (https://aistudio.google.com/apikey)
-GOOGLE_API_KEY=
@@ -1,182 +0,0 @@
-package main
-
-import (
-    "context"
-    "encoding/base64"
-    "fmt"
-    "net/http"
-    "os"
-    "strings"
-
-    tea "github.com/charmbracelet/bubbletea"
-
-    llm "gitea.stevedudenhoeffer.com/steve/go-llm"
-)
-
-// Message types for async operations
-
-// ChatResponseMsg contains the response from a chat completion
-type ChatResponseMsg struct {
-    Response llm.Response
-    Err      error
-}
-
-// ToolExecutionMsg contains results from tool execution
-type ToolExecutionMsg struct {
-    Results []llm.ToolCallResponse
-    Err     error
-}
-
-// ImageLoadedMsg contains a loaded image
-type ImageLoadedMsg struct {
-    Image llm.Image
-    Err   error
-}
-
-// sendChatRequest sends a chat completion request
-func sendChatRequest(chat llm.ChatCompletion, req llm.Request) tea.Cmd {
-    return func() tea.Msg {
-        resp, err := chat.ChatComplete(context.Background(), req)
-        return ChatResponseMsg{Response: resp, Err: err}
-    }
-}
-
-// executeTools executes tool calls and returns results
-func executeTools(toolbox llm.ToolBox, req llm.Request, resp llm.ResponseChoice) tea.Cmd {
-    return func() tea.Msg {
-        ctx := llm.NewContext(context.Background(), req, &resp, nil)
-        var results []llm.ToolCallResponse
-
-        for _, call := range resp.Calls {
-            result, err := toolbox.Execute(ctx, call)
-            results = append(results, llm.ToolCallResponse{
-                ID:     call.ID,
-                Result: result,
-                Error:  err,
-            })
-        }
-
-        return ToolExecutionMsg{Results: results, Err: nil}
-    }
-}
-
-// loadImageFromPath loads an image from a file path
-func loadImageFromPath(path string) tea.Cmd {
-    return func() tea.Msg {
-        // Clean up the path
-        path = strings.TrimSpace(path)
-        path = strings.Trim(path, "\"'")
-
-        // Read the file
-        data, err := os.ReadFile(path)
-        if err != nil {
-            return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)}
-        }
-
-        // Detect content type
-        contentType := http.DetectContentType(data)
-        if !strings.HasPrefix(contentType, "image/") {
-            return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)}
-        }
-
-        // Base64 encode
-        encoded := base64.StdEncoding.EncodeToString(data)
-
-        return ImageLoadedMsg{
-            Image: llm.Image{
-                Base64:      encoded,
-                ContentType: contentType,
-            },
-        }
-    }
-}
-
-// loadImageFromURL loads an image from a URL
-func loadImageFromURL(url string) tea.Cmd {
-    return func() tea.Msg {
-        url = strings.TrimSpace(url)
-
-        // For URL images, we can just use the URL directly
-        return ImageLoadedMsg{
-            Image: llm.Image{
-                Url: url,
-            },
-        }
-    }
-}
-
-// loadImageFromBase64 loads an image from base64 data
-func loadImageFromBase64(data string) tea.Cmd {
-    return func() tea.Msg {
-        data = strings.TrimSpace(data)
-
-        // Check if it's a data URL
-        if strings.HasPrefix(data, "data:") {
-            // Parse data URL: data:image/png;base64,....
-            parts := strings.SplitN(data, ",", 2)
-            if len(parts) != 2 {
-                return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")}
-            }
-
-            // Extract content type from first part
-            mediaType := strings.TrimPrefix(parts[0], "data:")
-            mediaType = strings.TrimSuffix(mediaType, ";base64")
-
-            return ImageLoadedMsg{
-                Image: llm.Image{
-                    Base64:      parts[1],
-                    ContentType: mediaType,
-                },
-            }
-        }
-
-        // Assume it's raw base64, try to detect content type
-        decoded, err := base64.StdEncoding.DecodeString(data)
-        if err != nil {
-            return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)}
-        }
-
-        contentType := http.DetectContentType(decoded)
-        if !strings.HasPrefix(contentType, "image/") {
-            return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)}
-        }
-
-        return ImageLoadedMsg{
-            Image: llm.Image{
-                Base64:      data,
-                ContentType: contentType,
-            },
-        }
-    }
-}
-
-// buildRequest builds a chat request from the current state
-func buildRequest(m *Model, userText string) llm.Request {
-    // Create the user message with any pending images
-    userMsg := llm.Message{
-        Role:   llm.RoleUser,
-        Text:   userText,
-        Images: m.pendingImages,
-    }
-
-    req := llm.Request{
-        Conversation: m.conversation,
-        Messages: []llm.Message{
-            {Role: llm.RoleSystem, Text: m.systemPrompt},
-            userMsg,
-        },
-        Temperature: m.temperature,
-    }
-
-    // Add toolbox if enabled
-    if m.toolsEnabled && len(m.toolbox.Functions()) > 0 {
-        req.Toolbox = m.toolbox.WithRequireTool(false)
-    }
-
-    return req
-}
-
-// buildFollowUpRequest builds a follow-up request after tool execution
-func buildFollowUpRequest(m *Model, previousReq llm.Request, resp llm.ResponseChoice, toolResults []llm.ToolCallResponse) llm.Request {
-    return previousReq.NextRequest(resp, toolResults)
-}
@@ -1,105 +0,0 @@
-package main
-
-import (
-    "fmt"
-    "math"
-    "strconv"
-    "strings"
-    "time"
-
-    llm "gitea.stevedudenhoeffer.com/steve/go-llm"
-)
-
-// TimeParams is the parameter struct for the GetTime function
-type TimeParams struct{}
-
-// GetTime returns the current time
-func GetTime(_ *llm.Context, _ TimeParams) (any, error) {
-    return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil
-}
-
-// CalcParams is the parameter struct for the Calculate function
-type CalcParams struct {
-    A  float64 `json:"a" description:"First number"`
-    B  float64 `json:"b" description:"Second number"`
-    Op string  `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"`
-}
-
-// Calculate performs basic math operations
-func Calculate(_ *llm.Context, params CalcParams) (any, error) {
-    switch strings.ToLower(params.Op) {
-    case "add", "+":
-        return params.A + params.B, nil
-    case "subtract", "sub", "-":
-        return params.A - params.B, nil
-    case "multiply", "mul", "*":
-        return params.A * params.B, nil
-    case "divide", "div", "/":
-        if params.B == 0 {
-            return nil, fmt.Errorf("division by zero")
-        }
-        return params.A / params.B, nil
-    case "power", "pow", "^":
-        return math.Pow(params.A, params.B), nil
-    case "sqrt":
-        if params.A < 0 {
-            return nil, fmt.Errorf("cannot take square root of negative number")
-        }
-        return math.Sqrt(params.A), nil
-    case "mod", "%":
-        return math.Mod(params.A, params.B), nil
-    default:
-        return nil, fmt.Errorf("unknown operation: %s", params.Op)
-    }
-}
-
-// WeatherParams is the parameter struct for the GetWeather function
-type WeatherParams struct {
-    Location string `json:"location" description:"City name or location"`
-}
-
-// GetWeather returns mock weather data (for demo purposes)
-func GetWeather(_ *llm.Context, params WeatherParams) (any, error) {
-    // This is a demo function - returns mock data
-    weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"}
-    temps := []int{65, 72, 58, 80, 45}
-
-    // Use location string to deterministically pick weather
-    idx := len(params.Location) % len(weathers)
-
-    return map[string]any{
-        "location":    params.Location,
-        "temperature": strconv.Itoa(temps[idx]) + "F",
-        "condition":   weathers[idx],
-        "humidity":    "45%",
-        "note":        "This is mock data for demonstration purposes",
-    }, nil
-}
-
-// RandomNumberParams is the parameter struct for the RandomNumber function
-type RandomNumberParams struct {
-    Min int `json:"min" description:"Minimum value (inclusive)"`
-    Max int `json:"max" description:"Maximum value (inclusive)"`
-}
-
-// RandomNumber generates a pseudo-random number (using current time nanoseconds)
-func RandomNumber(_ *llm.Context, params RandomNumberParams) (any, error) {
-    if params.Min > params.Max {
-        return nil, fmt.Errorf("min cannot be greater than max")
-    }
-    // Simple pseudo-random using time
-    n := time.Now().UnixNano()
-    rangeSize := params.Max - params.Min + 1
-    result := params.Min + int(n%int64(rangeSize))
-    return result, nil
-}
-
-// createDemoToolbox creates a toolbox with demo tools for testing
-func createDemoToolbox() llm.ToolBox {
-    return llm.NewToolBox(
-        llm.NewFunction("get_time", "Get the current date and time", GetTime),
-        llm.NewFunction("calculate", "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", Calculate),
-        llm.NewFunction("get_weather", "Get weather information for a location (demo data)", GetWeather),
-        llm.NewFunction("random_number", "Generate a random number between min and max", RandomNumber),
-    )
-}
@@ -1,120 +0,0 @@
-package llm
-
-import (
-    "context"
-    "time"
-)
-
-type Context struct {
-    context.Context
-    request         Request
-    response        *ResponseChoice
-    toolcall        *ToolCall
-    syntheticFields map[string]string
-}
-
-func (c *Context) ToNewRequest(toolResults ...ToolCallResponse) Request {
-    var res Request
-
-    res.Toolbox = c.request.Toolbox
-    res.Temperature = c.request.Temperature
-
-    res.Conversation = make([]Input, len(c.request.Conversation))
-    copy(res.Conversation, c.request.Conversation)
-
-    // now for every input message, convert those to an Input to add to the conversation
-    for _, msg := range c.request.Messages {
-        res.Conversation = append(res.Conversation, msg)
-    }
-
-    // if there are tool calls, then we need to add those to the conversation
-    if c.response != nil {
-        res.Conversation = append(res.Conversation, *c.response)
-    }
-
-    // if there are tool results, then we need to add those to the conversation
-    for _, result := range toolResults {
-        res.Conversation = append(res.Conversation, result)
-    }
-
-    return res
-}
-
-func NewContext(ctx context.Context, request Request, response *ResponseChoice, toolcall *ToolCall) *Context {
-    return &Context{Context: ctx, request: request, response: response, toolcall: toolcall}
-}
-
-func (c *Context) Request() Request {
-    return c.request
-}
-
-func (c *Context) Response() *ResponseChoice {
-    return c.response
-}
-
-func (c *Context) ToolCall() *ToolCall {
-    return c.toolcall
-}
-
-func (c *Context) SyntheticFields() map[string]string {
-    if c.syntheticFields == nil {
-        c.syntheticFields = map[string]string{}
-    }
-
-    return c.syntheticFields
-}
-
-func (c *Context) WithContext(ctx context.Context) *Context {
-    return &Context{Context: ctx, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
-}
-
-func (c *Context) WithRequest(request Request) *Context {
-    return &Context{Context: c.Context, request: request, response: c.response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
-}
-
-func (c *Context) WithResponse(response *ResponseChoice) *Context {
-    return &Context{Context: c.Context, request: c.request, response: response, toolcall: c.toolcall, syntheticFields: c.syntheticFields}
-}
-
-func (c *Context) WithToolCall(toolcall *ToolCall) *Context {
-    return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: toolcall, syntheticFields: c.syntheticFields}
-}
-
-func (c *Context) WithSyntheticFields(syntheticFields map[string]string) *Context {
-    return &Context{Context: c.Context, request: c.request, response: c.response, toolcall: c.toolcall, syntheticFields: syntheticFields}
-}
-
-func (c *Context) Deadline() (deadline time.Time, ok bool) {
-    return c.Context.Deadline()
-}
-
-func (c *Context) Done() <-chan struct{} {
-    return c.Context.Done()
-}
-
-func (c *Context) Err() error {
-    return c.Context.Err()
-}
-
-func (c *Context) Value(key any) any {
-    switch key {
-    case "request":
-        return c.request
-
-    case "response":
-        return c.response
-
-    case "toolcall":
-        return c.toolcall
-
-    case "syntheticFields":
-        return c.syntheticFields
-    }
-
-    return c.Context.Value(key)
-}
-
-func (c *Context) WithTimeout(timeout time.Duration) (*Context, context.CancelFunc) {
-    ctx, cancel := context.WithTimeout(c.Context, timeout)
-    return c.WithContext(ctx), cancel
-}
@@ -1,21 +0,0 @@
-package llm
-
-import "fmt"
-
-// Error is essentially just an error, but it is used to differentiate between a normal error and a fatal error.
-type Error struct {
-    error
-
-    Source    error
-    Parameter error
-}
-
-func newError(parent error, err error) Error {
-    e := fmt.Errorf("%w: %w", parent, err)
-    return Error{
-        error: e,
-
-        Source:    parent,
-        Parameter: err,
-    }
-}
@@ -1,136 +0,0 @@
-package llm
-
-import (
-    "context"
-    "encoding/json"
-    "fmt"
-    "log/slog"
-    "reflect"
-    "time"
-
-    "gitea.stevedudenhoeffer.com/steve/go-llm/schema"
-)
-
-type Function struct {
-    Name        string      `json:"name"`
-    Description string      `json:"description,omitempty"`
-    Strict      bool        `json:"strict,omitempty"`
-    Parameters  schema.Type `json:"parameters"`
-
-    Forced bool `json:"forced,omitempty"`
-
-    // Timeout is the maximum time to wait for the function to complete
-    Timeout time.Duration `json:"-"`
-
-    // fn is the function to call, only set if this is constructed with NewFunction
-    fn reflect.Value
-
-    paramType reflect.Type
-}
-
-func (f Function) WithSyntheticField(name string, description string) Function {
-    if obj, o := f.Parameters.(schema.Object); o {
-        f.Parameters = obj.WithSyntheticField(name, description)
-    }
-
-    return f
-}
-
-func (f Function) WithSyntheticFields(fieldsAndDescriptions map[string]string) Function {
-    if obj, o := f.Parameters.(schema.Object); o {
-        for k, v := range fieldsAndDescriptions {
-            obj = obj.WithSyntheticField(k, v)
-        }
-        f.Parameters = obj
-    }
-
-    return f
-}
-
-func (f Function) WithDescription(description string) Function {
-    f.Description = description
-    return f
-}
-
-func (f Function) Execute(ctx *Context, input string) (any, error) {
-    if !f.fn.IsValid() {
-        return "", fmt.Errorf("function %s is not implemented", f.Name)
-    }
-
-    slog.Info("Function.Execute", "name", f.Name, "input", input, "f", f.paramType)
-    // first, we need to parse the input into the struct
-    p := reflect.New(f.paramType)
-    fmt.Println("Function.Execute", f.Name, "input:", input)
-
-    var vals map[string]any
-    err := json.Unmarshal([]byte(input), &vals)
-
-    var syntheticFields map[string]string
-
-    // first eat up any synthetic fields
-    if obj, o := f.Parameters.(schema.Object); o {
-        for k := range obj.SyntheticFields() {
-            key := schema.SyntheticFieldPrefix + k
-            if val, ok := vals[key]; ok {
-                if syntheticFields == nil {
-                    syntheticFields = map[string]string{}
-                }
-
-                syntheticFields[k] = fmt.Sprint(val)
-                delete(vals, key)
-            }
-        }
-    }
-
-    // now for any remaining fields, re-marshal them into json and then unmarshal into the struct
-    b, err := json.Marshal(vals)
-    if err != nil {
-        return "", fmt.Errorf("failed to marshal input: %w (input: %s)", err, input)
-    }
-
-    // now we can unmarshal the input into the struct
-    err = json.Unmarshal(b, p.Interface())
-    if err != nil {
-        return "", fmt.Errorf("failed to unmarshal input: %w (input: %s)", err, input)
-    }
-
-    // now we can call the function
-    exec := func(ctx *Context) (any, error) {
-        out := f.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
-
-        if len(out) != 2 {
-            return "", fmt.Errorf("function %s must return two values, got %d", f.Name, len(out))
-        }
-
-        if out[1].IsNil() {
-            return out[0].Interface(), nil
-        }
-
-        return "", out[1].Interface().(error)
-    }
-
-    var cancel context.CancelFunc
-    if f.Timeout > 0 {
-        ctx, cancel = ctx.WithTimeout(f.Timeout)
-        defer cancel()
-    }
-
-    return exec(ctx)
-}
-
-type FunctionCall struct {
-    Name      string `json:"name,omitempty"`
-    Arguments string `json:"arguments,omitempty"`
-}
-
-func (fc *FunctionCall) toRaw() map[string]any {
-    res := map[string]interface{}{
-        "name": fc.Name,
|
|
||||||
}
|
|
||||||
|
|
||||||
if fc.Arguments != "" {
|
|
||||||
res["arguments"] = fc.Arguments
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
package llm

import (
	"reflect"

	"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)

// Parse takes a function pointer and returns a function object.
// fn must be a pointer to a function that takes a context.Context as its first argument, and then a struct that contains
// the parameters for the function. The struct must contain only the types: string, int, float64, bool, and pointers to
// those types.
// The struct parameters can have the following tags:
// - Description: a string that describes the parameter, passed to openaiImpl to tell it what the parameter is for

func NewFunction[T any](name string, description string, fn func(*Context, T) (any, error)) Function {
	var o T

	res := Function{
		Name: name,
		Description: description,
		Parameters: schema.GetType(o),
		fn: reflect.ValueOf(fn),
		paramType: reflect.TypeOf(o),
	}

	if res.fn.Kind() != reflect.Func {
		panic("fn must be a function")
	}
	if res.paramType.Kind() != reflect.Struct {
		panic("function parameter must be a struct")
	}

	return res
}
@@ -1,67 +0,0 @@
module gitea.stevedudenhoeffer.com/steve/go-llm

go 1.24.0

toolchain go1.24.2

require (
	github.com/charmbracelet/bubbles v0.21.0
	github.com/charmbracelet/bubbletea v1.3.10
	github.com/charmbracelet/lipgloss v1.1.0
	github.com/joho/godotenv v1.5.1
	github.com/liushuangls/go-anthropic/v2 v2.17.0
	github.com/modelcontextprotocol/go-sdk v1.2.0
	github.com/openai/openai-go v1.12.0
	golang.org/x/image v0.35.0
	google.golang.org/genai v1.43.0
)

require (
	cloud.google.com/go v0.123.0 // indirect
	cloud.google.com/go/auth v0.18.1 // indirect
	cloud.google.com/go/compute/metadata v0.9.0 // indirect
	github.com/atotto/clipboard v0.1.4 // indirect
	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
	github.com/charmbracelet/x/ansi v0.10.1 // indirect
	github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
	github.com/charmbracelet/x/term v0.2.1 // indirect
	github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
	github.com/felixge/httpsnoop v1.0.4 // indirect
	github.com/go-logr/logr v1.4.3 // indirect
	github.com/go-logr/stdr v1.2.2 // indirect
	github.com/google/go-cmp v0.7.0 // indirect
	github.com/google/jsonschema-go v0.3.0 // indirect
	github.com/google/s2a-go v0.1.9 // indirect
	github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
	github.com/googleapis/gax-go/v2 v2.16.0 // indirect
	github.com/gorilla/websocket v1.5.3 // indirect
	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/mattn/go-localereader v0.0.1 // indirect
	github.com/mattn/go-runewidth v0.0.16 // indirect
	github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
	github.com/muesli/cancelreader v0.2.2 // indirect
	github.com/muesli/termenv v0.16.0 // indirect
	github.com/rivo/uniseg v0.4.7 // indirect
	github.com/tidwall/gjson v1.18.0 // indirect
	github.com/tidwall/match v1.2.0 // indirect
	github.com/tidwall/pretty v1.2.1 // indirect
	github.com/tidwall/sjson v1.2.5 // indirect
	github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
	go.opentelemetry.io/auto/sdk v1.2.1 // indirect
	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
	go.opentelemetry.io/otel v1.39.0 // indirect
	go.opentelemetry.io/otel/metric v1.39.0 // indirect
	go.opentelemetry.io/otel/trace v1.39.0 // indirect
	golang.org/x/crypto v0.47.0 // indirect
	golang.org/x/net v0.49.0 // indirect
	golang.org/x/oauth2 v0.32.0 // indirect
	golang.org/x/sys v0.40.0 // indirect
	golang.org/x/text v0.33.0 // indirect
	google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect
	google.golang.org/grpc v1.78.0 // indirect
	google.golang.org/protobuf v1.36.11 // indirect
)
@@ -1,145 +0,0 @@
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao=
github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8=
github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y=
github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM=
google.golang.org/genai v1.43.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d h1:xXzuihhT3gL/ntduUZwHECzAn57E8dA6l8SOtYWdD8Q=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -1,165 +0,0 @@
package llm

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"google.golang.org/genai"
)

type googleImpl struct {
	key string
	model string
}

var _ LLM = googleImpl{}

func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
	g.model = modelVersion

	return g, nil
}

func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) {
	var contents []*genai.Content
	var cfg genai.GenerateContentConfig

	for _, tool := range in.Toolbox.Functions() {
		cfg.Tools = append(cfg.Tools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{
				{
					Name: tool.Name,
					Description: tool.Description,
					Parameters: tool.Parameters.GoogleParameters(),
				},
			},
		})
	}

	if in.Toolbox.RequiresTool() {
		cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
			Mode: genai.FunctionCallingConfigModeAny,
		}}
	}

	for _, c := range in.Messages {
		var role genai.Role
		switch c.Role {
		case RoleAssistant, RoleSystem:
			role = genai.RoleModel
		case RoleUser:
			role = genai.RoleUser
		}

		var parts []*genai.Part
		if c.Text != "" {
			parts = append(parts, genai.NewPartFromText(c.Text))
		}

		for _, img := range c.Images {
			if img.Url != "" {
				// gemini does not support URLs, so we need to download the image and convert it to a blob
				resp, err := http.Get(img.Url)
				if err != nil {
					panic(fmt.Sprintf("error downloading image: %v", err))
				}
				defer resp.Body.Close()

				if resp.ContentLength > 20*1024*1024 {
					panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
				}

				data, err := io.ReadAll(resp.Body)
				if err != nil {
					panic(fmt.Sprintf("error reading image data: %v", err))
				}

				mimeType := http.DetectContentType(data)
				switch mimeType {
				case "image/jpeg", "image/png", "image/gif":
					// MIME type is valid
				default:
					panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
				}

				parts = append(parts, genai.NewPartFromBytes(data, mimeType))
			} else {
				b, e := base64.StdEncoding.DecodeString(img.Base64)
				if e != nil {
					panic(fmt.Sprintf("error decoding base64: %v", e))
				}

				parts = append(parts, genai.NewPartFromBytes(b, img.ContentType))
			}
		}

		contents = append(contents, genai.NewContentFromParts(parts, role))
	}

	return contents, &cfg
}

func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
	res := Response{}

	for _, c := range in.Candidates {
		var choice ResponseChoice
		var set = false
		if c.Content != nil {
			for _, p := range c.Content.Parts {
				if p.Text != "" {
					set = true
					choice.Content = p.Text
				} else if p.FunctionCall != nil {
					v := p.FunctionCall
					b, e := json.Marshal(v.Args)
					if e != nil {
						return Response{}, fmt.Errorf("error marshalling args: %w", e)
					}

					call := ToolCall{
						ID: v.Name,
						FunctionCall: FunctionCall{
							Name: v.Name,
							Arguments: string(b),
						},
					}

					choice.Calls = append(choice.Calls, call)
					set = true
				}
			}
		}

		if set {
			choice.Role = RoleAssistant
			res.Choices = append(res.Choices, choice)
		}
	}

	return res, nil
}

func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey: g.key,
		Backend: genai.BackendGeminiAPI,
	})

	if err != nil {
		return Response{}, fmt.Errorf("error creating genai client: %w", err)
	}

	contents, cfg := g.requestToContents(req)

	resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg)
	if err != nil {
		return Response{}, fmt.Errorf("error generating content: %w", err)
	}

	return g.responseToLLMResponse(resp)
}
@@ -1,114 +0,0 @@
package imageutil

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"image"
	"image/gif"
	"image/jpeg"
	"net/http"

	"golang.org/x/image/draw"
)

// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns
// a base-64-encoded version that is at most maxLength in size, or an error.
func CompressImage(b64 string, maxLength int) (string, string, error) {
	raw, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return "", "", fmt.Errorf("base64 decode: %w", err)
	}

	mime := http.DetectContentType(raw)
	if len(raw) <= maxLength {
		return b64, mime, nil // small enough already
	}

	switch mime {
	case "image/gif":
		return compressGIF(raw, maxLength)

	default: // jpeg, png, webp, etc. -> treat as raster
		return compressRaster(raw, maxLength)
	}
}

// ---------- Raster path (jpeg / png / single-frame gif) ----------

func compressRaster(src []byte, maxLength int) (string, string, error) {
	img, _, err := image.Decode(bytes.NewReader(src))
	if err != nil {
		return "", "", fmt.Errorf("decode raster: %w", err)
	}

	quality := 95
	for {
		var buf bytes.Buffer
		if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil {
			return "", "", fmt.Errorf("jpeg encode: %w", err)
		}
		if buf.Len() <= maxLength {
			return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/jpeg", nil
		}

		if quality > 20 {
			quality -= 5
			continue
		}

		// down-scale 80%
		b := img.Bounds()
		if b.Dx() < 100 || b.Dy() < 100 {
			return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
		}
		dst := image.NewRGBA(image.Rect(0, 0, int(float64(b.Dx())*0.8), int(float64(b.Dy())*0.8)))
		draw.ApproxBiLinear.Scale(dst, dst.Bounds(), img, b, draw.Over, nil)
		img = dst
		quality = 95 // restart ladder
	}
}

// ---------- Animated GIF path ----------

func compressGIF(src []byte, maxLength int) (string, string, error) {
	g, err := gif.DecodeAll(bytes.NewReader(src))
	if err != nil {
		return "", "", fmt.Errorf("gif decode: %w", err)
	}

	for {
		var buf bytes.Buffer
		if err := gif.EncodeAll(&buf, g); err != nil {
			return "", "", fmt.Errorf("gif encode: %w", err)
		}
		if buf.Len() <= maxLength {
			return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
		}

		// down-scale every frame by 80%
		w, h := g.Config.Width, g.Config.Height
		if w < 100 || h < 100 {
			return "", "", fmt.Errorf("cannot compress animated GIF below 5 MiB without excessive quality loss")
		}

		nw, nh := int(float64(w)*0.8), int(float64(h)*0.8)
		for i, frm := range g.Image {
			// convert paletted frame -> RGBA for scaling
			rgba := image.NewRGBA(frm.Bounds())
			draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src)

			// scaled destination
			dst := image.NewRGBA(image.Rect(0, 0, nw, nh))
			draw.ApproxBiLinear.Scale(dst, dst.Bounds(), rgba, rgba.Bounds(), draw.Over, nil)

			// quantize back to paletted using default encoder quantizer
			paletted := image.NewPaletted(dst.Bounds(), nil)
			draw.FloydSteinberg.Draw(paletted, paletted.Bounds(), dst, dst.Bounds().Min)

			g.Image[i] = paletted
		}
		g.Config.Width, g.Config.Height = nw, nh
		// loop back and test size again ...
	}
}
@@ -1,30 +0,0 @@
package llm

import (
	"context"
)

// ChatCompletion is the interface for chat completion.
type ChatCompletion interface {
	ChatComplete(ctx context.Context, req Request) (Response, error)
}

// LLM is the interface for language model providers.
type LLM interface {
	ModelVersion(modelVersion string) (ChatCompletion, error)
}

// OpenAI creates a new OpenAI LLM provider with the given API key.
func OpenAI(key string) LLM {
	return openaiImpl{key: key}
}

// Anthropic creates a new Anthropic LLM provider with the given API key.
func Anthropic(key string) LLM {
	return anthropicImpl{key: key}
}

// Google creates a new Google LLM provider with the given API key.
func Google(key string) LLM {
	return googleImpl{key: key}
}
@@ -1,238 +0,0 @@
package llm

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"sync"

	"github.com/modelcontextprotocol/go-sdk/mcp"

	"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
)

// MCPServer represents a connection to an MCP server.
// It manages the lifecycle of the connection and provides access to the server's tools.
type MCPServer struct {
	// Name is a friendly name for this server (used for logging/identification)
	Name string

	// Command is the command to run the MCP server (for stdio transport)
	Command string

	// Args are arguments to pass to the command
	Args []string

	// Env are environment variables to set for the command (in addition to the current environment)
	Env []string

	// URL is the URL for SSE or HTTP transport (alternative to Command)
	URL string

	// Transport specifies the transport type: "stdio" (default), "sse", or "http"
	Transport string

	client  *mcp.Client
	session *mcp.ClientSession
	tools   map[string]*mcp.Tool // tool name -> tool definition
	mu      sync.RWMutex
}

// Connect establishes a connection to the MCP server.
func (m *MCPServer) Connect(ctx context.Context) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.session != nil {
		return nil // already connected
	}

	m.client = mcp.NewClient(&mcp.Implementation{
		Name:    "go-llm",
		Version: "1.0.0",
	}, nil)

	var transport mcp.Transport

	switch m.Transport {
	case "sse":
		transport = &mcp.SSEClientTransport{
			Endpoint: m.URL,
		}
	case "http":
		transport = &mcp.StreamableClientTransport{
			Endpoint: m.URL,
		}
	default: // "stdio" or empty
		cmd := exec.Command(m.Command, m.Args...)
		cmd.Env = append(os.Environ(), m.Env...)
		transport = &mcp.CommandTransport{
			Command: cmd,
		}
	}

	session, err := m.client.Connect(ctx, transport, nil)
	if err != nil {
		return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err)
	}

	m.session = session

	// Load tools
	m.tools = make(map[string]*mcp.Tool)
	for tool, err := range session.Tools(ctx, nil) {
		if err != nil {
			m.session.Close()
			m.session = nil
			return fmt.Errorf("failed to list tools from %s: %w", m.Name, err)
		}
		m.tools[tool.Name] = tool
	}

	return nil
}

// Close closes the connection to the MCP server.
func (m *MCPServer) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.session == nil {
		return nil
	}

	err := m.session.Close()
	m.session = nil
	m.tools = nil
	return err
}

// IsConnected returns true if the server is connected.
func (m *MCPServer) IsConnected() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.session != nil
}

// Tools returns the list of tool names available from this server.
func (m *MCPServer) Tools() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var names []string
	for name := range m.tools {
		names = append(names, name)
	}
	return names
}

// HasTool returns true if this server provides the named tool.
func (m *MCPServer) HasTool(name string) bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	_, ok := m.tools[name]
	return ok
}

// CallTool calls a tool on the MCP server.
func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) {
	m.mu.RLock()
	session := m.session
	m.mu.RUnlock()

	if session == nil {
		return nil, fmt.Errorf("not connected to MCP server %s", m.Name)
	}

	result, err := session.CallTool(ctx, &mcp.CallToolParams{
		Name:      name,
		Arguments: arguments,
	})
	if err != nil {
		return nil, err
	}

	// Process the result content
	if len(result.Content) == 0 {
		return nil, nil
	}

	// If there's a single text content, return it as a string
	if len(result.Content) == 1 {
		if textContent, ok := result.Content[0].(*mcp.TextContent); ok {
			return textContent.Text, nil
		}
	}

	// For multiple contents or non-text content, serialize to a string
	return contentToString(result.Content), nil
}

// toFunction converts an MCP tool to a go-llm Function (for schema purposes only).
func (m *MCPServer) toFunction(tool *mcp.Tool) Function {
	var inputSchema map[string]any
	if tool.InputSchema != nil {
		data, err := json.Marshal(tool.InputSchema)
		if err == nil {
			_ = json.Unmarshal(data, &inputSchema)
		}
	}

	if inputSchema == nil {
		inputSchema = map[string]any{
			"type":       "object",
			"properties": map[string]any{},
		}
	}

	return Function{
		Name:        tool.Name,
		Description: tool.Description,
		Parameters:  schema.NewRaw(inputSchema),
	}
}

// contentToString converts MCP content to a string representation.
func contentToString(content []mcp.Content) string {
	var parts []string
	for _, c := range content {
		switch tc := c.(type) {
		case *mcp.TextContent:
			parts = append(parts, tc.Text)
		default:
			if data, err := json.Marshal(c); err == nil {
				parts = append(parts, string(data))
			}
		}
	}
	if len(parts) == 1 {
		return parts[0]
	}
	data, _ := json.Marshal(parts)
	return string(data)
}

// WithMCPServer adds an MCP server to the toolbox.
// The server must already be connected. Tools from the server will be available
// for use, and tool calls will be routed to the appropriate server.
func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox {
	if t.mcpServers == nil {
		t.mcpServers = make(map[string]*MCPServer)
	}

	server.mu.RLock()
	defer server.mu.RUnlock()

	for name, tool := range server.tools {
		// Add the function definition (for schema)
		fn := server.toFunction(tool)
		t.functions[name] = fn

		// Track which server owns this tool
		t.mcpServers[name] = server
	}

	return t
}
@@ -1,115 +0,0 @@
package llm

// Role represents the role of a message in a conversation.
type Role string

const (
	RoleSystem    Role = "system"
	RoleUser      Role = "user"
	RoleAssistant Role = "assistant"
)

// Image represents an image that can be included in a message.
type Image struct {
	Base64      string
	ContentType string
	Url         string
}

func (i Image) toRaw() map[string]any {
	res := map[string]any{
		"base64":      i.Base64,
		"contenttype": i.ContentType,
		"url":         i.Url,
	}

	return res
}

func (i *Image) fromRaw(raw map[string]any) Image {
	var res Image

	res.Base64 = raw["base64"].(string)
	res.ContentType = raw["contenttype"].(string)
	res.Url = raw["url"].(string)

	return res
}

// Message represents a message in a conversation.
type Message struct {
	Role   Role
	Name   string
	Text   string
	Images []Image
}

func (m Message) toRaw() map[string]any {
	res := map[string]any{
		"role": m.Role,
		"name": m.Name,
		"text": m.Text,
	}

	images := make([]map[string]any, 0, len(m.Images))
	for _, img := range m.Images {
		images = append(images, img.toRaw())
	}

	res["images"] = images

	return res
}

func (m *Message) fromRaw(raw map[string]any) Message {
	var res Message

	res.Role = Role(raw["role"].(string))
	res.Name = raw["name"].(string)
	res.Text = raw["text"].(string)

	images := raw["images"].([]map[string]any)
	for _, img := range images {
		var i Image

		res.Images = append(res.Images, i.fromRaw(img))
	}

	return res
}

// ToolCall represents a tool call made by an assistant.
type ToolCall struct {
	ID           string
	FunctionCall FunctionCall
}

func (t ToolCall) toRaw() map[string]any {
	res := map[string]any{
		"id": t.ID,
	}

	res["function"] = t.FunctionCall.toRaw()

	return res
}

// ToolCallResponse represents the response to a tool call.
type ToolCallResponse struct {
	ID     string
	Result any
	Error  error
}

func (t ToolCallResponse) toRaw() map[string]any {
	res := map[string]any{
		"id":     t.ID,
		"result": t.Result,
	}

	if t.Error != nil {
		res["error"] = t.Error.Error()
	}

	return res
}
@@ -1,322 +0,0 @@
package llm

import (
	"context"
	"fmt"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"
)

type openaiImpl struct {
	key     string
	model   string
	baseUrl string
}

var _ LLM = openaiImpl{}

func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
	res := openai.ChatCompletionNewParams{
		Model: o.model,
	}

	for _, i := range request.Conversation {
		res.Messages = append(res.Messages, inputToChatCompletionMessages(i, o.model)...)
	}

	for _, msg := range request.Messages {
		res.Messages = append(res.Messages, messageToChatCompletionMessages(msg, o.model)...)
	}

	for _, tool := range request.Toolbox.Functions() {
		res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
			Type: "function",
			Function: shared.FunctionDefinitionParam{
				Name:        tool.Name,
				Description: openai.String(tool.Description),
				Strict:      openai.Bool(tool.Strict),
				Parameters:  tool.Parameters.OpenAIParameters(),
			},
		})
	}

	if request.Toolbox.RequiresTool() {
		res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: openai.String("required"),
		}
	}

	if request.Temperature != nil {
		// These are known models that do not support custom temperatures:
		// all o* models and gpt-5* models.
		if !strings.HasPrefix(o.model, "o") && !strings.HasPrefix(o.model, "gpt-5") {
			res.Temperature = openai.Float(*request.Temperature)
		}
	}

	return res
}

func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
	var res Response

	if response == nil {
		return res
	}

	if len(response.Choices) == 0 {
		return res
	}

	for _, choice := range response.Choices {
		var toolCalls []ToolCall
		for _, call := range choice.Message.ToolCalls {
			toolCall := ToolCall{
				ID: call.ID,
				FunctionCall: FunctionCall{
					Name:      call.Function.Name,
					Arguments: strings.TrimSpace(call.Function.Arguments),
				},
			}

			toolCalls = append(toolCalls, toolCall)
		}
		res.Choices = append(res.Choices, ResponseChoice{
			Content: choice.Message.Content,
			Role:    Role(choice.Message.Role),
			Refusal: choice.Message.Refusal,
			Calls:   toolCalls,
		})
	}

	return res
}

func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
	var opts = []option.RequestOption{
		option.WithAPIKey(o.key),
	}

	if o.baseUrl != "" {
		opts = append(opts, option.WithBaseURL(o.baseUrl))
	}

	cl := openai.NewClient(opts...)

	req := o.newRequestToOpenAIRequest(request)

	resp, err := cl.Chat.Completions.New(ctx, req)
	if err != nil {
		return Response{}, fmt.Errorf("unhandled openai error: %w", err)
	}

	return o.responseToLLMResponse(resp), nil
}

func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
	return openaiImpl{
		key:     o.key,
		model:   modelVersion,
		baseUrl: o.baseUrl,
	}, nil
}

// inputToChatCompletionMessages converts an Input to OpenAI chat completion messages.
func inputToChatCompletionMessages(input Input, model string) []openai.ChatCompletionMessageParamUnion {
	switch v := input.(type) {
	case Message:
		return messageToChatCompletionMessages(v, model)
	case ToolCall:
		return toolCallToChatCompletionMessages(v)
	case ToolCallResponse:
		return toolCallResponseToChatCompletionMessages(v)
	case ResponseChoice:
		return responseChoiceToChatCompletionMessages(v)
	default:
		return nil
	}
}

func messageToChatCompletionMessages(m Message, model string) []openai.ChatCompletionMessageParamUnion {
	var res openai.ChatCompletionMessageParamUnion

	var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
	var textContent param.Opt[string]

	for _, img := range m.Images {
		if img.Base64 != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: "data:" + img.ContentType + ";base64," + img.Base64,
						},
					},
				},
			)
		} else if img.Url != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: img.Url,
						},
					},
				},
			)
		}
	}

	if m.Text != "" {
		if len(arrayOfContentParts) > 0 {
			// include the message text as a content part alongside the image parts
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: m.Text,
					},
				},
			)
		} else {
			textContent = openai.String(m.Text)
		}
	}

	a := strings.Split(model, "-")

	useSystemInsteadOfDeveloper := true
	if len(a) > 1 && a[0][0] == 'o' {
		useSystemInsteadOfDeveloper = false
	}

	switch m.Role {
	case RoleSystem:
		if useSystemInsteadOfDeveloper {
			res = openai.ChatCompletionMessageParamUnion{
				OfSystem: &openai.ChatCompletionSystemMessageParam{
					Content: openai.ChatCompletionSystemMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		} else {
			res = openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}

	case RoleUser:
		var name param.Opt[string]
		if m.Name != "" {
			name = openai.String(m.Name)
		}

		res = openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{
				Name: name,
				Content: openai.ChatCompletionUserMessageParamContentUnion{
					OfString:              textContent,
					OfArrayOfContentParts: arrayOfContentParts,
				},
			},
		}

	case RoleAssistant:
		var name param.Opt[string]
		if m.Name != "" {
			name = openai.String(m.Name)
		}

		res = openai.ChatCompletionMessageParamUnion{
			OfAssistant: &openai.ChatCompletionAssistantMessageParam{
				Name: name,
				Content: openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}
	}

	return []openai.ChatCompletionMessageParamUnion{res}
}

func toolCallToChatCompletionMessages(t ToolCall) []openai.ChatCompletionMessageParamUnion {
	return []openai.ChatCompletionMessageParamUnion{{
		OfAssistant: &openai.ChatCompletionAssistantMessageParam{
			ToolCalls: []openai.ChatCompletionMessageToolCallParam{
				{
					ID: t.ID,
					Function: openai.ChatCompletionMessageToolCallFunctionParam{
						Name:      t.FunctionCall.Name,
						Arguments: t.FunctionCall.Arguments,
					},
				},
			},
		},
	}}
}

func toolCallResponseToChatCompletionMessages(t ToolCallResponse) []openai.ChatCompletionMessageParamUnion {
	var refusal string
	if t.Error != nil {
		refusal = t.Error.Error()
	}

	result := t.Result
	if refusal != "" {
		if result != "" {
			result = fmt.Sprint(result) + " (error in execution: " + refusal + ")"
		} else {
			result = "error in execution: " + refusal
		}
	}

	return []openai.ChatCompletionMessageParamUnion{{
		OfTool: &openai.ChatCompletionToolMessageParam{
			ToolCallID: t.ID,
			Content: openai.ChatCompletionToolMessageParamContentUnion{
				OfString: openai.String(fmt.Sprint(result)),
			},
		},
	}}
}

func responseChoiceToChatCompletionMessages(r ResponseChoice) []openai.ChatCompletionMessageParamUnion {
	var as openai.ChatCompletionAssistantMessageParam

	if r.Name != "" {
		as.Name = openai.String(r.Name)
	}
	if r.Refusal != "" {
		as.Refusal = openai.String(r.Refusal)
	}

	if r.Content != "" {
		as.Content.OfString = openai.String(r.Content)
	}

	for _, call := range r.Calls {
		as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
			ID: call.ID,
			Function: openai.ChatCompletionMessageToolCallFunctionParam{
				Name:      call.FunctionCall.Name,
				Arguments: call.FunctionCall.Arguments,
			},
		})
	}
	return []openai.ChatCompletionMessageParamUnion{
		{
			OfAssistant: &as,
		},
	}
}
@@ -1,219 +0,0 @@
package llm

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

type openaiTranscriber struct {
	key     string
	model   string
	baseUrl string
}

var _ Transcriber = openaiTranscriber{}

// OpenAITranscriber creates a transcriber backed by OpenAI's audio models.
// If model is empty, whisper-1 is used by default.
func OpenAITranscriber(key string, model string) Transcriber {
	if strings.TrimSpace(model) == "" {
		model = "whisper-1"
	}
	return openaiTranscriber{
		key:   key,
		model: model,
	}
}

func (o openaiTranscriber) Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) {
	if len(wav) == 0 {
		return Transcription{}, fmt.Errorf("wav data is empty")
	}

	format := opts.ResponseFormat
	if format == "" {
		if strings.HasPrefix(o.model, "gpt-4o") {
			format = TranscriptionResponseFormatJSON
		} else {
			format = TranscriptionResponseFormatVerboseJSON
		}
	}

	if format != TranscriptionResponseFormatJSON && format != TranscriptionResponseFormatVerboseJSON {
		return Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output")
	}

	if len(opts.TimestampGranularities) > 0 && format != TranscriptionResponseFormatVerboseJSON {
		return Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json")
	}

	params := openai.AudioTranscriptionNewParams{
		File:  openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"),
		Model: openai.AudioModel(o.model),
	}

	if opts.Language != "" {
		params.Language = openai.String(opts.Language)
	}
	if opts.Prompt != "" {
		params.Prompt = openai.String(opts.Prompt)
	}
	if opts.Temperature != nil {
		params.Temperature = openai.Float(*opts.Temperature)
	}

	params.ResponseFormat = openai.AudioResponseFormat(format)

	if opts.IncludeLogprobs {
		params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs}
	}

	if len(opts.TimestampGranularities) > 0 {
		for _, granularity := range opts.TimestampGranularities {
			params.TimestampGranularities = append(params.TimestampGranularities, string(granularity))
		}
	}

	clientOptions := []option.RequestOption{
		option.WithAPIKey(o.key),
	}
	if o.baseUrl != "" {
		clientOptions = append(clientOptions, option.WithBaseURL(o.baseUrl))
	}

	client := openai.NewClient(clientOptions...)
	resp, err := client.Audio.Transcriptions.New(ctx, params)
	if err != nil {
		return Transcription{}, fmt.Errorf("openai transcription failed: %w", err)
	}

	return openaiTranscriptionToResult(o.model, resp), nil
}

type openaiVerboseTranscription struct {
	Text     string                 `json:"text"`
	Language string                 `json:"language"`
	Duration float64                `json:"duration"`
	Segments []openaiVerboseSegment `json:"segments"`
	Words    []openaiVerboseWord    `json:"words"`
}

type openaiVerboseSegment struct {
	ID               int                 `json:"id"`
	Start            float64             `json:"start"`
	End              float64             `json:"end"`
	Text             string              `json:"text"`
	Tokens           []int               `json:"tokens"`
	AvgLogprob       *float64            `json:"avg_logprob"`
	CompressionRatio *float64            `json:"compression_ratio"`
	NoSpeechProb     *float64            `json:"no_speech_prob"`
	Words            []openaiVerboseWord `json:"words"`
}

type openaiVerboseWord struct {
	Word  string  `json:"word"`
	Start float64 `json:"start"`
	End   float64 `json:"end"`
}

func openaiTranscriptionToResult(model string, resp *openai.Transcription) Transcription {
	result := Transcription{
		Provider: "openai",
		Model:    model,
	}
	if resp == nil {
		return result
	}

	result.Text = resp.Text
	result.RawJSON = resp.RawJSON()

	for _, logprob := range resp.Logprobs {
		result.Logprobs = append(result.Logprobs, TranscriptionTokenLogprob{
			Token:   logprob.Token,
			Bytes:   logprob.Bytes,
			Logprob: logprob.Logprob,
		})
	}

	if usage := openaiUsageToTranscriptionUsage(resp.Usage); usage.Type != "" {
		result.Usage = usage
	}

	if result.RawJSON == "" {
		return result
	}

	var verbose openaiVerboseTranscription
	if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil {
		return result
	}

	if verbose.Text != "" {
		result.Text = verbose.Text
	}
	result.Language = verbose.Language
	result.DurationSeconds = verbose.Duration

	for _, seg := range verbose.Segments {
		segment := TranscriptionSegment{
			ID:               seg.ID,
			Start:            seg.Start,
			End:              seg.End,
			Text:             seg.Text,
			Tokens:           append([]int(nil), seg.Tokens...),
			AvgLogprob:       seg.AvgLogprob,
			CompressionRatio: seg.CompressionRatio,
			NoSpeechProb:     seg.NoSpeechProb,
		}

		for _, word := range seg.Words {
			segment.Words = append(segment.Words, TranscriptionWord{
				Word:  word.Word,
				Start: word.Start,
				End:   word.End,
			})
		}

		result.Segments = append(result.Segments, segment)
	}

	for _, word := range verbose.Words {
		result.Words = append(result.Words, TranscriptionWord{
			Word:  word.Word,
			Start: word.Start,
			End:   word.End,
		})
	}

	return result
}

func openaiUsageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) TranscriptionUsage {
	switch usage.Type {
	case "tokens":
		tokens := usage.AsTokens()
		return TranscriptionUsage{
			Type:         usage.Type,
			InputTokens:  tokens.InputTokens,
			OutputTokens: tokens.OutputTokens,
			TotalTokens:  tokens.TotalTokens,
			AudioTokens:  tokens.InputTokenDetails.AudioTokens,
			TextTokens:   tokens.InputTokenDetails.TextTokens,
		}
	case "duration":
		duration := usage.AsDuration()
		return TranscriptionUsage{
			Type:    usage.Type,
			Seconds: duration.Seconds,
		}
	default:
		return TranscriptionUsage{}
	}
}
@@ -1,50 +0,0 @@
package llm

import (
	"strings"
)

// Providers are the allowed shortcuts in the providers, e.g.: if you set { "openai": OpenAI("key") } that'll allow
// for the "openai" provider to be used when parsed.
type Providers map[string]LLM

// Parse will parse the provided input and attempt to return a LLM chat completion interface.
// Input should be in the provided format:
//   - provider/modelname
//
// where provider is a key inside Providers, and the modelname being passed to the LLM interface's GetModel
func (providers Providers) Parse(input string) ChatCompletion {
	sections := strings.Split(input, "/")

	var provider LLM
	var ok bool
	var modelVersion string

	if len(sections) < 2 {
		// is there a default provider?
		provider, ok = providers["default"]
		if !ok {
			panic("expected format: \"provider/model\" or provide a \"default\" provider to the Parse callback")
		}

		modelVersion = sections[0]
	} else {
		provider, ok = providers[sections[0]]
		modelVersion = sections[1]
	}

	if !ok {
		panic("expected format: \"provider/model\" or provide a \"default\" provider to the Parse callback")
	}

	if provider == nil {
		panic("unknown provider: " + sections[0])
	}

	res, err := provider.ModelVersion(modelVersion)
	if err != nil {
		panic(err)
	}

	return res
}
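The `provider/model` convention that `Parse` implements can be sketched in isolation. This is a minimal, self-contained sketch; `splitProviderModel` is a hypothetical helper name, not part of the library, and it uses `strings.SplitN` so that model names containing a `/` survive intact (the original uses `strings.Split` and only reads the first two sections):

```go
package main

import (
	"fmt"
	"strings"
)

// splitProviderModel splits "provider/model" input, falling back to a
// default provider name when no slash is present. Hypothetical helper
// mirroring the Parse logic above.
func splitProviderModel(input, defaultProvider string) (provider, model string, ok bool) {
	sections := strings.SplitN(input, "/", 2)
	if len(sections) < 2 {
		// no slash: the whole input is the model; require a default provider
		if defaultProvider == "" {
			return "", "", false
		}
		return defaultProvider, sections[0], true
	}
	return sections[0], sections[1], true
}

func main() {
	p, m, _ := splitProviderModel("openai/gpt-4o", "")
	fmt.Println(p, m)

	p, m, _ = splitProviderModel("gpt-4o", "default")
	fmt.Println(p, m)
}
```

The library's version panics instead of returning `ok=false`; returning a flag is just a choice here to keep the sketch testable.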
@@ -1,11 +0,0 @@
// Package anthropic provides the Anthropic LLM provider.
package anthropic

import (
	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)

// New creates a new Anthropic LLM provider with the given API key.
func New(key string) llm.LLM {
	return llm.Anthropic(key)
}
@@ -1,11 +0,0 @@
// Package google provides the Google LLM provider.
package google

import (
	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)

// New creates a new Google LLM provider with the given API key.
func New(key string) llm.LLM {
	return llm.Google(key)
}
@@ -1,11 +0,0 @@
// Package openai provides the OpenAI LLM provider.
package openai

import (
	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
)

// New creates a new OpenAI LLM provider with the given API key.
func New(key string) llm.LLM {
	return llm.OpenAI(key)
}
@@ -1,51 +0,0 @@
package llm

// Input is the interface for conversation inputs.
// Types that implement this interface can be part of a conversation:
// Message, ToolCall, ToolCallResponse, and ResponseChoice.
type Input interface {
	// isInput is a marker method to ensure only valid types implement this interface.
	isInput()
}

// Implement Input interface for all valid input types.
func (Message) isInput()          {}
func (ToolCall) isInput()         {}
func (ToolCallResponse) isInput() {}
func (ResponseChoice) isInput()   {}

// Request represents a request to a language model.
type Request struct {
	Conversation []Input
	Messages     []Message
	Toolbox      ToolBox
	Temperature  *float64
}

// NextRequest will take the current request's conversation, messages, the response, and any tool results, and
// return a new request with the conversation updated to include the response and tool results.
func (req Request) NextRequest(resp ResponseChoice, toolResults []ToolCallResponse) Request {
	var res Request

	res.Toolbox = req.Toolbox
	res.Temperature = req.Temperature

	res.Conversation = make([]Input, len(req.Conversation))
	copy(res.Conversation, req.Conversation)

	// now for every input message, convert those to an Input to add to the conversation
	for _, msg := range req.Messages {
		res.Conversation = append(res.Conversation, msg)
	}

	if resp.Content != "" || resp.Refusal != "" || len(resp.Calls) > 0 {
		res.Conversation = append(res.Conversation, resp)
	}

	// if there are tool results, then we need to add those to the conversation
	for _, result := range toolResults {
		res.Conversation = append(res.Conversation, result)
	}

	return res
}
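`NextRequest` allocates a fresh slice and copies before appending, so the caller's `Request` is never mutated and can be retried or branched. The same copy-on-append pattern can be shown with plain strings standing in for `Input` values; `appendTurns` is an illustrative name, not library API:

```go
package main

import "fmt"

// appendTurns copies the existing conversation before appending, so the
// caller's slice is never mutated — the same pattern NextRequest uses
// with its []Input conversation.
func appendTurns(conv []string, extra ...string) []string {
	next := make([]string, len(conv), len(conv)+len(extra))
	copy(next, conv)
	return append(next, extra...)
}

func main() {
	orig := []string{"user: hi"}
	next := appendTurns(orig, "assistant: hello", "tool: result")

	// orig keeps its length; only next sees the new turns
	fmt.Println(len(orig), len(next))
}
```

A plain `append(conv, extra...)` could share the backing array with `conv` and overwrite a sibling branch's turns; pre-sizing with `make` + `copy` avoids that aliasing.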
@@ -1,52 +0,0 @@
package llm

// ResponseChoice represents a single choice in a response.
type ResponseChoice struct {
	Index   int
	Role    Role
	Content string
	Refusal string
	Name    string
	Calls   []ToolCall
}

func (r ResponseChoice) toRaw() map[string]any {
	res := map[string]any{
		"index":   r.Index,
		"role":    r.Role,
		"content": r.Content,
		"refusal": r.Refusal,
		"name":    r.Name,
	}

	calls := make([]map[string]any, 0, len(r.Calls))
	for _, call := range r.Calls {
		calls = append(calls, call.toRaw())
	}

	res["tool_calls"] = calls

	return res
}

func (r ResponseChoice) toInput() []Input {
	var res []Input

	for _, call := range r.Calls {
		res = append(res, call)
	}

	if r.Content != "" || r.Refusal != "" {
		res = append(res, Message{
			Role: RoleAssistant,
			Text: r.Content,
		})
	}

	return res
}

// Response represents a response from a language model.
type Response struct {
	Choices []ResponseChoice
}
@@ -1,142 +0,0 @@
package schema

import (
	"reflect"
	"strings"
)

// GetType will, given an interface{} that is a struct (NOT a pointer to a struct), return the Type of the struct that
// can be used to generate a json schema and build an object from a parsed json object.
func GetType(a any) Type {
	t := reflect.TypeOf(a)

	if t.Kind() != reflect.Struct {
		panic("GetType expects a struct")
	}

	return getObject(t)
}

func getFromType(t reflect.Type, b basic) Type {
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
		b.required = false
	}

	switch t.Kind() {
	case reflect.String:
		b.DataType = TypeString
		b.typeName = "string"
		return b

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		b.DataType = TypeInteger
		b.typeName = "integer"
		return b

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		b.DataType = TypeInteger
		b.typeName = "integer"
		return b

	case reflect.Float32, reflect.Float64:
		b.DataType = TypeNumber
		b.typeName = "number"
		return b

	case reflect.Bool:
		b.DataType = TypeBoolean
		b.typeName = "boolean"
		return b

	case reflect.Struct:
		o := getObject(t)

		o.basic.required = b.required
		o.basic.index = b.index
		o.basic.description = b.description

		return o

	case reflect.Slice:
		return getArray(t)

	default:
		panic("unhandled default case for " + t.Kind().String() + " in getFromType")
	}
}

func getField(f reflect.StructField, index int) Type {
	b := basic{
		index:       index,
		required:    true,
		description: "",
	}

	t := f.Type

	// if the tag "description" is set, use that as the description
	if desc, ok := f.Tag.Lookup("description"); ok {
		b.description = desc
	}

	// now if the tag "enum" is set, we need to create an enum type
	if v, ok := f.Tag.Lookup("enum"); ok {
		vals := strings.Split(v, ",")

		for i := 0; i < len(vals); i++ {
			vals[i] = strings.TrimSpace(vals[i])

			if vals[i] == "" {
				vals = append(vals[:i], vals[i+1:]...)
			}
		}

		b.DataType = TypeString
		b.typeName = "string"
		return enum{
			basic:  b,
			values: vals,
		}
	}

	return getFromType(t, b)
}

func getObject(t reflect.Type) Object {
	fields := make(map[string]Type, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		if field.Anonymous {
			// if the field is anonymous, we need to get the fields of the anonymous struct
			// and add them to the object
			anon := getObject(field.Type)
			for k, v := range anon.fields {
				fields[k] = v
			}
			continue
		} else {
			fields[field.Name] = getField(field, i)
		}
	}

	return Object{
		basic:  basic{DataType: TypeObject, typeName: "object"},
		fields: fields,
	}
}

func getArray(t reflect.Type) array {
	res := array{
		basic: basic{
			DataType: TypeArray,
			typeName: "array",
		},
	}

	res.items = getFromType(t.Elem(), basic{})

	return res
}
@@ -1,77 +0,0 @@
package schema

import (
	"errors"
	"reflect"

	"github.com/openai/openai-go"
	"google.golang.org/genai"
)

type array struct {
	basic

	// items is the schema of the items in the array
	items Type
}

func (a array) OpenAIParameters() openai.FunctionParameters {
	return openai.FunctionParameters{
		"type":        "array",
		"description": a.Description(),
		"items":       a.items.OpenAIParameters(),
	}
}

func (a array) GoogleParameters() *genai.Schema {
	return &genai.Schema{
		Type:        genai.TypeArray,
		Description: a.Description(),
		Items:       a.items.GoogleParameters(),
	}
}

func (a array) AnthropicInputSchema() map[string]any {
	return map[string]any{
		"type":        "array",
		"description": a.Description(),
		"items":       a.items.AnthropicInputSchema(),
	}
}

func (a array) FromAny(val any) (reflect.Value, error) {
	v := reflect.ValueOf(val)

	// first realize we may have a pointer to a slice if this type is not required
	if !a.required && v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	if v.Kind() != reflect.Slice {
		return reflect.Value{}, errors.New("expected slice, got " + v.Kind().String())
	}

	// if the slice is nil, we can just return it
	if v.IsNil() {
		return v, nil
	}

	// if the slice is not nil, we need to convert each item
	items := make([]reflect.Value, v.Len())
	for i := 0; i < v.Len(); i++ {
		item, err := a.items.FromAny(v.Index(i).Interface())
		if err != nil {
			return reflect.Value{}, err
		}
		items[i] = item
	}

	return reflect.ValueOf(items), nil
}

func (a array) SetValue(obj reflect.Value, val reflect.Value) {
	if !a.required {
		val = val.Addr()
	}
	obj.Field(a.index).Set(val)
}
@@ -1,165 +0,0 @@
package schema

import (
	"errors"
	"reflect"
	"strconv"

	"github.com/openai/openai-go"
	"google.golang.org/genai"
)

// just enforcing that basic implements Type
var _ Type = basic{}

type DataType string

const (
	TypeString  DataType = "string"
	TypeInteger DataType = "integer"
	TypeNumber  DataType = "number"
	TypeBoolean DataType = "boolean"
	TypeObject  DataType = "object"
	TypeArray   DataType = "array"
)

type basic struct {
	DataType
	typeName string

	// index is the position of the parameter in the StructField of the function's parameter struct
	index int

	// required is a flag that indicates whether the parameter is required in the function's parameter struct.
	// this is inferred by if the parameter is a pointer type or not.
	required bool

	// description is a llm-readable description of the parameter passed to openai
	description string
}

func (b basic) OpenAIParameters() openai.FunctionParameters {
	return openai.FunctionParameters{
		"type":        b.typeName,
		"description": b.description,
	}
}

func (b basic) GoogleParameters() *genai.Schema {
	var t = genai.TypeUnspecified

	switch b.DataType {
	case TypeString:
		t = genai.TypeString
	case TypeInteger:
		t = genai.TypeInteger
	case TypeNumber:
		t = genai.TypeNumber
	case TypeBoolean:
		t = genai.TypeBoolean
	case TypeObject:
		t = genai.TypeObject
	case TypeArray:
		t = genai.TypeArray
	default:
		t = genai.TypeUnspecified
	}
	return &genai.Schema{
		Type:        t,
		Description: b.description,
	}
}

func (b basic) AnthropicInputSchema() map[string]any {
	var t = "string"

	switch b.DataType {
	case TypeString:
		t = "string"
	case TypeInteger:
		t = "integer"
	case TypeNumber:
		t = "number"
	case TypeBoolean:
		t = "boolean"
	case TypeObject:
		t = "object"
	case TypeArray:
		t = "array"
	default:
		t = "unknown"
	}

	return map[string]any{
		"type":        t,
		"description": b.description,
	}
}

func (b basic) Required() bool {
	return b.required
}

func (b basic) Description() string {
	return b.description
}

func (b basic) FromAny(val any) (reflect.Value, error) {
	v := reflect.ValueOf(val)

	switch b.DataType {
	case TypeString:
		var val = v.String()

		return reflect.ValueOf(val), nil

	case TypeInteger:
		if v.Kind() == reflect.Float64 {
			return v.Convert(reflect.TypeOf(int(0))), nil
		} else if v.Kind() != reflect.Int {
			return reflect.Value{}, errors.New("expected int, got " + v.Kind().String())
		} else {
			return v, nil
		}

	case TypeNumber:
		if v.Kind() == reflect.Float64 {
			return v.Convert(reflect.TypeOf(float64(0))), nil
		} else if v.Kind() != reflect.Float64 {
			return reflect.Value{}, errors.New("expected float64, got " + v.Kind().String())
		} else {
			return v, nil
		}

	case TypeBoolean:
		if v.Kind() == reflect.Bool {
			return v, nil
		} else if v.Kind() == reflect.String {
			b, err := strconv.ParseBool(v.String())
			if err != nil {
				return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
			}
			return reflect.ValueOf(b), nil
		} else {
			return reflect.Value{}, errors.New("expected bool, got " + v.Kind().String())
		}

	default:
		return reflect.Value{}, errors.New("unknown type")
	}
}

func (b basic) SetValueOnField(obj reflect.Value, val reflect.Value) {
	// if this basic type is not required that means it's a pointer type
	// so we need to create a new value of the type of the pointer
	if !b.required {
		vv := reflect.New(obj.Field(b.index).Type().Elem())

		// and then set the value of the pointer to the new value
		vv.Elem().Set(val)

		obj.Field(b.index).Set(vv)
		return
	}
	obj.Field(b.index).Set(val)
}
@@ -1,61 +0,0 @@
package schema

import (
	"errors"
	"reflect"
	"slices"

	"github.com/openai/openai-go"
	"google.golang.org/genai"
)

type enum struct {
	basic

	values []string
}

func (e enum) FunctionParameters() openai.FunctionParameters {
	return openai.FunctionParameters{
		"type":        "string",
		"description": e.Description(),
		"enum":        e.values,
	}
}

func (e enum) GoogleParameters() *genai.Schema {
	return &genai.Schema{
		Type:        genai.TypeString,
		Description: e.Description(),
		Enum:        e.values,
	}
}

func (e enum) AnthropicInputSchema() map[string]any {
	return map[string]any{
		"type":        "string",
		"description": e.Description(),
		"enum":        e.values,
	}
}

func (e enum) FromAny(val any) (reflect.Value, error) {
	v := reflect.ValueOf(val)
	if v.Kind() != reflect.String {
		return reflect.Value{}, errors.New("expected string, got " + v.Kind().String())
	}

	s := v.String()
	if !slices.Contains(e.values, s) {
		return reflect.Value{}, errors.New("value " + s + " not in enum")
	}

	return v, nil
}

func (e enum) SetValueOnField(obj reflect.Value, val reflect.Value) {
	if !e.required {
		val = val.Addr()
	}
	obj.Field(e.index).Set(val)
}
@@ -1,169 +0,0 @@
package schema

import (
	"errors"
	"reflect"

	"github.com/openai/openai-go"
	"google.golang.org/genai"
)

const (
	// SyntheticFieldPrefix is any prefix that is added to any synthetic fields that are added to the object, to prevent
	// collisions with the fields in the struct.
	SyntheticFieldPrefix = "__"
)

type Object struct {
	basic

	ref reflect.Type

	fields map[string]Type

	// syntheticFields are fields that are not in the struct but are generated by a system.
	synetheticFields map[string]Type
}

func (o Object) WithSyntheticField(name string, description string) Object {
	if o.synetheticFields == nil {
		o.synetheticFields = map[string]Type{}
	}

	o.synetheticFields[name] = basic{
		DataType:    TypeString,
		typeName:    "string",
		index:       -1,
		required:    false,
		description: description,
	}

	return o
}

func (o Object) SyntheticFields() map[string]Type {
	return o.synetheticFields
}

func (o Object) OpenAIParameters() openai.FunctionParameters {
	var properties = map[string]openai.FunctionParameters{}
	var required []string
	for k, v := range o.fields {
		properties[k] = v.OpenAIParameters()
		if v.Required() {
			required = append(required, k)
		}
	}

	for k, v := range o.synetheticFields {
		properties[SyntheticFieldPrefix+k] = v.OpenAIParameters()
		if v.Required() {
			required = append(required, SyntheticFieldPrefix+k)
		}
	}

	var res = openai.FunctionParameters{
		"type":        "object",
		"description": o.Description(),
		"properties":  properties,
	}

	if len(required) > 0 {
		res["required"] = required
	}

	return res
}

func (o Object) GoogleParameters() *genai.Schema {
	var properties = map[string]*genai.Schema{}
	var required []string
	for k, v := range o.fields {
		properties[k] = v.GoogleParameters()
		if v.Required() {
			required = append(required, k)
		}
	}

	var res = &genai.Schema{
		Type:        genai.TypeObject,
		Description: o.Description(),
		Properties:  properties,
	}

	if len(required) > 0 {
		res.Required = required
	}

	return res
}

func (o Object) AnthropicInputSchema() map[string]any {
	var properties = map[string]any{}
	var required []string
	for k, v := range o.fields {
		properties[k] = v.AnthropicInputSchema()
		if v.Required() {
			required = append(required, k)
		}
	}

	var res = map[string]any{
		"type":        "object",
		"description": o.Description(),
		"properties":  properties,
	}

	if len(required) > 0 {
		res["required"] = required
	}

	return res
}

// FromAny converts the value from any to the correct type, returning the value, and an error if any
func (o Object) FromAny(val any) (reflect.Value, error) {
	// if the value is nil, we can't do anything
	if val == nil {
		return reflect.Value{}, nil
	}

	// now make a new object of the type we're trying to parse
	obj := reflect.New(o.ref).Elem()

	// now we need to iterate over the fields and set the values
	for k, v := range o.fields {
		// get the field by name
		field := obj.FieldByName(k)
		if !field.IsValid() {
			return reflect.Value{}, errors.New("field " + k + " not found")
		}

		// get the value from the map
		val2, ok := val.(map[string]interface{})[k]
		if !ok {
			return reflect.Value{}, errors.New("field " + k + " not found in map")
		}

		// now we need to convert the value to the correct type
		val3, err := v.FromAny(val2)
		if err != nil {
			return reflect.Value{}, err
		}

		// now we need to set the value on the field
		v.SetValueOnField(field, val3)
	}

	return obj, nil
}

func (o Object) SetValueOnField(obj reflect.Value, val reflect.Value) {
	// if this basic type is not required that means it's a pointer type so we need to set the value to the address of the value
	if !o.required {
		val = val.Addr()
	}

	obj.Field(o.index).Set(val)
}
@@ -1,134 +0,0 @@
package schema

import (
	"encoding/json"
	"fmt"
	"reflect"

	"github.com/openai/openai-go"
	"google.golang.org/genai"
)

// Raw represents a raw JSON schema that is passed through directly.
// This is used for MCP tools where we receive the schema from the server.
type Raw struct {
	schema map[string]any
}

// NewRaw creates a new Raw schema from a map.
func NewRaw(schema map[string]any) Raw {
	if schema == nil {
		schema = map[string]any{
			"type":       "object",
			"properties": map[string]any{},
		}
	}
	return Raw{schema: schema}
}

// NewRawFromJSON creates a new Raw schema from JSON bytes.
func NewRawFromJSON(data []byte) (Raw, error) {
	var schema map[string]any
	if err := json.Unmarshal(data, &schema); err != nil {
		return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err)
	}
	return NewRaw(schema), nil
}

func (r Raw) OpenAIParameters() openai.FunctionParameters {
	return openai.FunctionParameters(r.schema)
}

func (r Raw) GoogleParameters() *genai.Schema {
	return mapToGenaiSchema(r.schema)
}

func (r Raw) AnthropicInputSchema() map[string]any {
	return r.schema
}

func (r Raw) Required() bool {
	return false
}

func (r Raw) Description() string {
	if desc, ok := r.schema["description"].(string); ok {
		return desc
	}
	return ""
}

func (r Raw) FromAny(val any) (reflect.Value, error) {
	return reflect.ValueOf(val), nil
}

func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) {
	// No-op for raw schemas
}

// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema
func mapToGenaiSchema(m map[string]any) *genai.Schema {
	if m == nil {
		return nil
	}

	schema := &genai.Schema{}

	// Type
	if t, ok := m["type"].(string); ok {
		switch t {
		case "string":
			schema.Type = genai.TypeString
		case "number":
			schema.Type = genai.TypeNumber
		case "integer":
			schema.Type = genai.TypeInteger
		case "boolean":
			schema.Type = genai.TypeBoolean
		case "array":
			schema.Type = genai.TypeArray
		case "object":
			schema.Type = genai.TypeObject
		}
	}

	// Description
	if desc, ok := m["description"].(string); ok {
		schema.Description = desc
	}

	// Enum
	if enum, ok := m["enum"].([]any); ok {
		for _, e := range enum {
			if s, ok := e.(string); ok {
				schema.Enum = append(schema.Enum, s)
			}
		}
	}

	// Properties (for objects)
	if props, ok := m["properties"].(map[string]any); ok {
		schema.Properties = make(map[string]*genai.Schema)
		for k, v := range props {
			if vm, ok := v.(map[string]any); ok {
				schema.Properties[k] = mapToGenaiSchema(vm)
			}
		}
	}

	// Required
	if req, ok := m["required"].([]any); ok {
		for _, r := range req {
			if s, ok := r.(string); ok {
				schema.Required = append(schema.Required, s)
			}
		}
	}

	// Items (for arrays)
	if items, ok := m["items"].(map[string]any); ok {
		schema.Items = mapToGenaiSchema(items)
	}

	return schema
}
@@ -1,23 +0,0 @@
package schema

import (
	"reflect"

	"github.com/openai/openai-go"
	"google.golang.org/genai"
)

type Type interface {
	OpenAIParameters() openai.FunctionParameters
	GoogleParameters() *genai.Schema
	AnthropicInputSchema() map[string]any

	//SchemaType() jsonschema.DataType
	//Definition() jsonschema.Definition

	Required() bool
	Description() string

	FromAny(any) (reflect.Value, error)
	SetValueOnField(obj reflect.Value, val reflect.Value)
}
@@ -1,174 +0,0 @@
package llm

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
)

// ToolBox is a collection of tools that OpenAI can use to execute functions.
// It is a wrapper around a collection of functions, and provides a way to automatically call the correct function with
// the correct parameters.
type ToolBox struct {
	functions       map[string]Function
	mcpServers      map[string]*MCPServer // tool name -> MCP server that provides it
	dontRequireTool bool
}

func NewToolBox(fns ...Function) ToolBox {
	res := ToolBox{
		functions: map[string]Function{},
	}

	for _, f := range fns {
		res.functions[f.Name] = f
	}

	return res
}

func (t ToolBox) Functions() []Function {
	var res []Function

	for _, f := range t.functions {
		res = append(res, f)
	}

	return res
}

func (t ToolBox) WithFunction(f Function) ToolBox {
	t.functions[f.Name] = f

	return t
}

func (t ToolBox) WithFunctions(fns ...Function) ToolBox {
	for _, f := range fns {
		t.functions[f.Name] = f
	}

	return t
}

func (t ToolBox) WithSyntheticFieldsAddedToAllFunctions(fieldsAndDescriptions map[string]string) ToolBox {
	for k, v := range t.functions {
		t.functions[k] = v.WithSyntheticFields(fieldsAndDescriptions)
	}

	return t
}

func (t ToolBox) ForEachFunction(fn func(f Function)) {
	for _, f := range t.functions {
		fn(f)
	}
}

func (t ToolBox) WithFunctionRemoved(name string) ToolBox {
	delete(t.functions, name)
	return t
}

func (t ToolBox) WithRequireTool(val bool) ToolBox {
	t.dontRequireTool = !val
	return t
}

func (t ToolBox) RequiresTool() bool {
	return !t.dontRequireTool && len(t.functions) > 0
}

func (t ToolBox) ToToolChoice() any {
	if len(t.functions) == 0 {
		return nil
	}

	return "required"
}

var (
	ErrFunctionNotFound = errors.New("function not found")
)

func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
	// Check if this is an MCP tool
	if server, ok := t.mcpServers[functionName]; ok {
		var args map[string]any
		if params != "" {
			if err := json.Unmarshal([]byte(params), &args); err != nil {
				return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err)
			}
		}
		return server.CallTool(ctx, functionName, args)
	}

	// Regular function
	f, ok := t.functions[functionName]

	if !ok {
		return "", newError(ErrFunctionNotFound, fmt.Errorf("function \"%s\" not found", functionName))
	}

	return f.Execute(ctx, params)
}

func (t ToolBox) Execute(ctx *Context, toolCall ToolCall) (any, error) {
	return t.executeFunction(ctx.WithToolCall(&toolCall), toolCall.FunctionCall.Name, toolCall.FunctionCall.Arguments)
}

func (t ToolBox) GetSyntheticParametersFromFunctionContext(ctx context.Context) map[string]string {
	val := ctx.Value("syntheticParameters")

	if val == nil {
		return nil
	}

	syntheticParameters, ok := val.(map[string]string)
	if !ok {
		return nil
	}

	return syntheticParameters
}

// ExecuteCallbacks will execute all the tool calls in the given list, and call the given callbacks when a new function is created, and when a function is finished.
// OnNewFunction is called when a new function is created
// OnFunctionFinished is called when a function is finished
func (t ToolBox) ExecuteCallbacks(ctx *Context, toolCalls []ToolCall, OnNewFunction func(ctx context.Context, funcName string, parameter string) (any, error), OnFunctionFinished func(ctx context.Context, funcName string, parameter string, result any, err error, newFunctionResult any) error) ([]ToolCallResponse, error) {
	var res []ToolCallResponse

	for _, call := range toolCalls {
		ctx := ctx.WithToolCall(&call)
		if call.FunctionCall.Name == "" {
			return nil, newError(ErrFunctionNotFound, errors.New("function name is empty"))
|
|
||||||
}
|
|
||||||
|
|
||||||
var arg any
|
|
||||||
if OnNewFunction != nil {
|
|
||||||
var err error
|
|
||||||
arg, err = OnNewFunction(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, newError(ErrFunctionNotFound, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out, err := t.Execute(ctx, call)
|
|
||||||
|
|
||||||
if OnFunctionFinished != nil {
|
|
||||||
err := OnFunctionFinished(ctx, call.FunctionCall.Name, call.FunctionCall.Arguments, out, err, arg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, newError(ErrFunctionNotFound, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res = append(res, ToolCallResponse{
|
|
||||||
ID: call.ID,
|
|
||||||
Result: out,
|
|
||||||
Error: err,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
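A subtlety in the `With*` builders above: they take value receivers and return the modified `ToolBox`, but `functions` is a map, so the returned "copy" still shares storage with the original — an addition through one value is visible through the other. A self-contained sketch of that aliasing (the `box` type here is illustrative, not the library's):

```go
package main

import "fmt"

// box mimics the ToolBox pattern: value receiver, mutate a map field,
// return the (shallow) copy.
type box struct {
	functions map[string]string
}

func (b box) withFunction(name, body string) box {
	b.functions[name] = body // mutates the shared map, not a private copy
	return b
}

func main() {
	orig := box{functions: map[string]string{}}
	derived := orig.withFunction("get_time", "returns the current time")

	// Both values see the addition because the map header is shared.
	fmt.Println(len(orig.functions), len(derived.functions)) // 1 1
}
```

Whether that sharing is a feature or a footgun depends on usage; it only bites if callers expect the pre-`WithFunction` value to stay unchanged.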
-145
@@ -1,145 +0,0 @@
package llm

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
)

// Transcriber abstracts a speech-to-text model implementation.
type Transcriber interface {
	Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error)
}

// TranscriptionResponseFormat controls the output format requested from a transcriber.
type TranscriptionResponseFormat string

const (
	TranscriptionResponseFormatJSON        TranscriptionResponseFormat = "json"
	TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json"
	TranscriptionResponseFormatText        TranscriptionResponseFormat = "text"
	TranscriptionResponseFormatSRT         TranscriptionResponseFormat = "srt"
	TranscriptionResponseFormatVTT         TranscriptionResponseFormat = "vtt"
)

// TranscriptionTimestampGranularity defines the requested timestamp detail.
type TranscriptionTimestampGranularity string

const (
	TranscriptionTimestampGranularityWord    TranscriptionTimestampGranularity = "word"
	TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
)

// TranscriptionOptions configures transcription behavior.
type TranscriptionOptions struct {
	Language               string
	Prompt                 string
	Temperature            *float64
	ResponseFormat         TranscriptionResponseFormat
	TimestampGranularities []TranscriptionTimestampGranularity
	IncludeLogprobs        bool
}

// Transcription captures a normalized transcription result.
type Transcription struct {
	Provider        string
	Model           string
	Text            string
	Language        string
	DurationSeconds float64
	Segments        []TranscriptionSegment
	Words           []TranscriptionWord
	Logprobs        []TranscriptionTokenLogprob
	Usage           TranscriptionUsage
	RawJSON         string
}

// TranscriptionSegment provides a coarse time-sliced transcription segment.
type TranscriptionSegment struct {
	ID               int
	Start            float64
	End              float64
	Text             string
	Tokens           []int
	AvgLogprob       *float64
	CompressionRatio *float64
	NoSpeechProb     *float64
	Words            []TranscriptionWord
}

// TranscriptionWord provides a word-level timestamp.
type TranscriptionWord struct {
	Word       string
	Start      float64
	End        float64
	Confidence *float64
}

// TranscriptionTokenLogprob captures token-level log probability details.
type TranscriptionTokenLogprob struct {
	Token   string
	Bytes   []float64
	Logprob float64
}

// TranscriptionUsage captures token or duration usage details.
type TranscriptionUsage struct {
	Type         string
	InputTokens  int64
	OutputTokens int64
	TotalTokens  int64
	AudioTokens  int64
	TextTokens   int64
	Seconds      float64
}

// TranscribeFile converts an audio file to WAV and transcribes it.
func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) {
	if transcriber == nil {
		return Transcription{}, fmt.Errorf("transcriber is nil")
	}

	wav, err := audioFileToWav(ctx, filename)
	if err != nil {
		return Transcription{}, err
	}

	return transcriber.Transcribe(ctx, wav, opts)
}

func audioFileToWav(ctx context.Context, filename string) ([]byte, error) {
	if filename == "" {
		return nil, fmt.Errorf("filename is empty")
	}

	if strings.EqualFold(filepath.Ext(filename), ".wav") {
		data, err := os.ReadFile(filename)
		if err != nil {
			return nil, fmt.Errorf("read wav file: %w", err)
		}
		return data, nil
	}

	tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav")
	if err != nil {
		return nil, fmt.Errorf("create temp wav file: %w", err)
	}
	tempPath := tempFile.Name()
	_ = tempFile.Close()
	defer os.Remove(tempPath)

	cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath)
	if output, err := cmd.CombinedOutput(); err != nil {
		return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output)))
	}

	data, err := os.ReadFile(tempPath)
	if err != nil {
		return nil, fmt.Errorf("read converted wav file: %w", err)
	}

	return data, nil
}
@@ -0,0 +1,27 @@
# go-llm CLI environment variables
# Copy this file to .env and fill in the keys for providers you use.

# OpenAI API Key (https://platform.openai.com/api-keys)
OPENAI_API_KEY=

# Anthropic API Key (https://console.anthropic.com/settings/keys)
ANTHROPIC_API_KEY=

# Google AI API Key (https://aistudio.google.com/apikey)
GOOGLE_API_KEY=

# DeepSeek API Key (https://platform.deepseek.com)
DEEPSEEK_API_KEY=

# Moonshot / Kimi API Key (https://platform.moonshot.ai)
MOONSHOT_API_KEY=

# xAI / Grok API Key (https://x.ai/api)
XAI_API_KEY=

# Groq API Key (https://console.groq.com/keys)
GROQ_API_KEY=

# Ollama runs locally with no API key required.
# Override the endpoint if you're not using localhost:11434.
# OLLAMA_BASE_URL=http://localhost:11434/v1
@@ -0,0 +1,136 @@
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"os"
	"strings"

	tea "github.com/charmbracelet/bubbletea"

	llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)

// Message types for async operations.

// ChatResponseMsg contains the response from a chat completion.
type ChatResponseMsg struct {
	Response llm.Response
	Err      error
}

// ToolExecutionMsg contains results from executing tool calls, one Message
// (RoleTool) per ToolCall, in the same order.
type ToolExecutionMsg struct {
	Results []llm.Message
	Err     error
}

// ImageLoadedMsg contains a loaded image.
type ImageLoadedMsg struct {
	Image llm.Image
	Err   error
}

// sendChatRequest sends a completion request with the current conversation,
// returning a ChatResponseMsg tea.Msg when the provider responds.
func sendChatRequest(model *llm.Model, messages []llm.Message, toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) tea.Cmd {
	return func() tea.Msg {
		opts := buildOpts(toolbox, toolsEnabled, temperature)
		resp, err := model.Complete(context.Background(), messages, opts...)
		return ChatResponseMsg{Response: resp, Err: err}
	}
}

// executeTools runs each tool call via the toolbox and returns ToolExecutionMsg
// with one RoleTool Message per call, in the same order.
func executeTools(toolbox *llm.ToolBox, calls []llm.ToolCall) tea.Cmd {
	return func() tea.Msg {
		ctx := context.Background()
		results, err := toolbox.ExecuteAll(ctx, calls)
		return ToolExecutionMsg{Results: results, Err: err}
	}
}

// buildOpts constructs RequestOptions from the current CLI state.
func buildOpts(toolbox *llm.ToolBox, toolsEnabled bool, temperature *float64) []llm.RequestOption {
	var opts []llm.RequestOption
	if toolsEnabled && toolbox != nil && len(toolbox.AllTools()) > 0 {
		opts = append(opts, llm.WithTools(toolbox))
	}
	if temperature != nil {
		opts = append(opts, llm.WithTemperature(*temperature))
	}
	return opts
}

// loadImageFromPath loads an image from a file path.
func loadImageFromPath(path string) tea.Cmd {
	return func() tea.Msg {
		path = strings.TrimSpace(path)
		path = strings.Trim(path, "\"'")

		data, err := os.ReadFile(path)
		if err != nil {
			return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)}
		}

		contentType := http.DetectContentType(data)
		if !strings.HasPrefix(contentType, "image/") {
			return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)}
		}

		return ImageLoadedMsg{
			Image: llm.Image{
				Base64:      base64.StdEncoding.EncodeToString(data),
				ContentType: contentType,
			},
		}
	}
}

// loadImageFromURL loads an image from a URL (kept as URL, not fetched).
func loadImageFromURL(url string) tea.Cmd {
	return func() tea.Msg {
		return ImageLoadedMsg{Image: llm.Image{URL: strings.TrimSpace(url)}}
	}
}

// loadImageFromBase64 loads an image from base64 data (raw or data: URL).
func loadImageFromBase64(data string) tea.Cmd {
	return func() tea.Msg {
		data = strings.TrimSpace(data)

		if strings.HasPrefix(data, "data:") {
			parts := strings.SplitN(data, ",", 2)
			if len(parts) != 2 {
				return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")}
			}
			mediaType := strings.TrimPrefix(parts[0], "data:")
			mediaType = strings.TrimSuffix(mediaType, ";base64")
			return ImageLoadedMsg{
				Image: llm.Image{
					Base64:      parts[1],
					ContentType: mediaType,
				},
			}
		}

		decoded, err := base64.StdEncoding.DecodeString(data)
		if err != nil {
			return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)}
		}
		contentType := http.DetectContentType(decoded)
		if !strings.HasPrefix(contentType, "image/") {
			return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)}
		}
		return ImageLoadedMsg{
			Image: llm.Image{
				Base64:      data,
				ContentType: contentType,
			},
		}
	}
}
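`loadImageFromBase64` handles `data:` URLs by splitting on the first comma and trimming the `;base64` suffix from the media type. That parsing step in isolation, so it can be exercised directly (the helper name is illustrative, not part of the CLI):

```go
package main

import (
	"fmt"
	"strings"
)

// splitDataURL extracts the media type and base64 payload from a data: URL,
// mirroring the parsing in loadImageFromBase64.
func splitDataURL(s string) (mediaType, payload string, err error) {
	parts := strings.SplitN(s, ",", 2)
	if len(parts) != 2 {
		return "", "", fmt.Errorf("invalid data URL format")
	}
	mediaType = strings.TrimPrefix(parts[0], "data:")
	mediaType = strings.TrimSuffix(mediaType, ";base64")
	return mediaType, parts[1], nil
}

func main() {
	mt, b64, err := splitDataURL("data:image/png;base64,iVBORw0KGgo=")
	fmt.Println(mt, len(b64), err) // image/png 12 <nil>
}
```

Note that, as in the CLI code, the payload is kept as base64 text; only the raw-base64 branch decodes it (to sniff the content type).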
@@ -7,10 +7,10 @@ import (
 	"github.com/charmbracelet/bubbles/viewport"
 	tea "github.com/charmbracelet/bubbletea"
 
-	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
+	llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
 )
 
-// State represents the current view/screen of the application
+// State represents the current view/screen of the application.
 type State int
 
 const (
@@ -23,43 +23,42 @@ const (
 	StateAPIKeyInput
 )
 
-// DisplayMessage represents a message for display in the UI
+// DisplayMessage represents a message for display in the UI.
 type DisplayMessage struct {
 	Role    llm.Role
 	Content string
 	Images  int // number of images attached
 }
 
-// ProviderInfo contains information about a provider
-type ProviderInfo struct {
-	Name       string
-	EnvVar     string
-	Models     []string
+// ProviderEntry is a CLI-local view of a registered provider, enriched with
+// UI state (which model is currently chosen, whether we have a key, etc.).
+type ProviderEntry struct {
+	Info       llm.ProviderInfo
 	HasAPIKey  bool
 	ModelIndex int
 }
 
-// Model is the main Bubble Tea model
+// Model is the main Bubble Tea model.
 type Model struct {
 	// State
 	state         State
 	previousState State
 
 	// Provider
-	provider      llm.LLM
+	client        *llm.Client
+	chat          *llm.Model
 	providerName  string
-	chat          llm.ChatCompletion
 	modelName     string
 	apiKeys       map[string]string
-	providers     []ProviderInfo
+	providers     []ProviderEntry
 	providerIndex int
 
 	// Conversation
-	conversation []llm.Input
+	conversation []llm.Message
 	messages     []DisplayMessage
 
 	// Tools
-	toolbox      llm.ToolBox
+	toolbox      *llm.ToolBox
 	toolsEnabled bool
 
 	// Settings
@@ -90,7 +89,7 @@ type Model struct {
 	apiKeyInput textinput.Model
 }
 
-// InitialModel creates and returns the initial model
+// InitialModel creates and returns the initial model.
 func InitialModel() Model {
 	ti := textinput.New()
 	ti.Placeholder = "Type your message..."
@@ -104,60 +103,21 @@ func InitialModel() Model {
 	aki.Width = 60
 	aki.EchoMode = textinput.EchoPassword
 
-	// Initialize providers with environment variable checks
-	providers := []ProviderInfo{
-		{
-			Name:   "OpenAI",
-			EnvVar: "OPENAI_API_KEY",
-			Models: []string{
-				"gpt-4.1",
-				"gpt-4.1-mini",
-				"gpt-4.1-nano",
-				"gpt-4o",
-				"gpt-4o-mini",
-				"gpt-4-turbo",
-				"gpt-3.5-turbo",
-				"o1",
-				"o1-mini",
-				"o1-preview",
-				"o3-mini",
-			},
-		},
-		{
-			Name:   "Anthropic",
-			EnvVar: "ANTHROPIC_API_KEY",
-			Models: []string{
-				"claude-sonnet-4-20250514",
-				"claude-opus-4-20250514",
-				"claude-3-7-sonnet-20250219",
-				"claude-3-5-sonnet-20241022",
-				"claude-3-5-haiku-20241022",
-				"claude-3-opus-20240229",
-				"claude-3-sonnet-20240229",
-				"claude-3-haiku-20240307",
-			},
-		},
-		{
-			Name:   "Google",
-			EnvVar: "GOOGLE_API_KEY",
-			Models: []string{
-				"gemini-2.0-flash",
-				"gemini-2.0-flash-lite",
-				"gemini-1.5-pro",
-				"gemini-1.5-flash",
-				"gemini-1.5-flash-8b",
-				"gemini-1.0-pro",
-			},
-		},
-	}
-
-	// Check for API keys in environment
+	// Build provider list from the go-llm registry.
+	registry := llm.Providers()
+	providers := make([]ProviderEntry, len(registry))
 	apiKeys := make(map[string]string)
-	for i := range providers {
-		if key := os.Getenv(providers[i].EnvVar); key != "" {
-			apiKeys[providers[i].Name] = key
-			providers[i].HasAPIKey = true
+	for i, info := range registry {
+		entry := ProviderEntry{Info: info}
+		if info.EnvKey == "" {
+			// Key-less provider (e.g., Ollama).
+			entry.HasAPIKey = true
+		} else if key := os.Getenv(info.EnvKey); key != "" {
+			apiKeys[info.Name] = key
+			entry.HasAPIKey = true
 		}
+		providers[i] = entry
 	}
 
 	m := Model{
@@ -170,97 +130,87 @@ func InitialModel() Model {
 		toolbox:      createDemoToolbox(),
 		toolsEnabled: false,
 		messages:     []DisplayMessage{},
-		conversation: []llm.Input{},
+		conversation: []llm.Message{},
 	}
 
-	// Build list items for provider selection
+	// Build list items for provider selection.
 	m.listItems = make([]string, len(providers))
 	for i, p := range providers {
 		status := " (no key)"
 		if p.HasAPIKey {
 			status = " (ready)"
+			if p.Info.EnvKey == "" {
+				status = " (local)"
+			}
 		}
-		m.listItems[i] = p.Name + status
+		m.listItems[i] = p.Info.DisplayName + status
 	}
 
 	return m
 }
 
-// Init initializes the model
+// Init initializes the model.
func (m Model) Init() tea.Cmd {
 	return textinput.Blink
 }
 
-// selectProvider sets up the selected provider
+// selectProvider sets up the selected provider.
 func (m *Model) selectProvider(index int) error {
 	if index < 0 || index >= len(m.providers) {
 		return nil
 	}
 
 	p := m.providers[index]
-	key, ok := m.apiKeys[p.Name]
-	if !ok || key == "" {
+	key := m.apiKeys[p.Info.Name] // empty for key-less providers like Ollama
+	if p.Info.EnvKey != "" && key == "" {
 		return nil
 	}
 
-	m.providerName = p.Name
+	m.providerName = p.Info.DisplayName
 	m.providerIndex = index
+	m.client = p.Info.New(key)
 
-	switch p.Name {
-	case "OpenAI":
-		m.provider = llm.OpenAI(key)
-	case "Anthropic":
-		m.provider = llm.Anthropic(key)
-	case "Google":
-		m.provider = llm.Google(key)
-	}
-
-	// Select default model
-	if len(p.Models) > 0 {
+	// Select default model.
+	if len(p.Info.Models) > 0 {
 		return m.selectModel(p.ModelIndex)
 	}
 
 	return nil
 }
 
-// selectModel sets the current model
+// selectModel sets the current model.
 func (m *Model) selectModel(index int) error {
-	if m.provider == nil {
+	if m.client == nil {
 		return nil
 	}
 
 	p := m.providers[m.providerIndex]
-	if index < 0 || index >= len(p.Models) {
+	if index < 0 || index >= len(p.Info.Models) {
 		return nil
 	}
 
-	modelName := p.Models[index]
-	chat, err := m.provider.ModelVersion(modelName)
-	if err != nil {
-		return err
-	}
-
-	m.chat = chat
+	modelName := p.Info.Models[index]
+	m.chat = m.client.Model(modelName)
 	m.modelName = modelName
 	m.providers[m.providerIndex].ModelIndex = index
 
 	return nil
 }
 
-// newConversation resets the conversation
+// newConversation resets the conversation.
 func (m *Model) newConversation() {
-	m.conversation = []llm.Input{}
+	m.conversation = []llm.Message{}
 	m.messages = []DisplayMessage{}
 	m.pendingImages = []llm.Image{}
 	m.err = nil
 }
 
-// addUserMessage adds a user message to the conversation
+// addUserMessage adds a user message to the conversation.
 func (m *Model) addUserMessage(text string, images []llm.Image) {
 	msg := llm.Message{
-		Role:   llm.RoleUser,
-		Text:   text,
-		Images: images,
+		Role:    llm.RoleUser,
+		Content: llm.Content{Text: text, Images: images},
 	}
 	m.conversation = append(m.conversation, msg)
 	m.messages = append(m.messages, DisplayMessage{
@@ -270,7 +220,7 @@ func (m *Model) addUserMessage(text string, images []llm.Image) {
 	})
 }
 
-// addAssistantMessage adds an assistant message to the conversation
+// addAssistantMessage adds an assistant message to the conversation display.
 func (m *Model) addAssistantMessage(content string) {
 	m.messages = append(m.messages, DisplayMessage{
 		Role:    llm.RoleAssistant,
@@ -278,7 +228,7 @@ func (m *Model) addAssistantMessage(content string) {
 	})
 }
 
-// addToolCallMessage adds a tool call message to display
+// addToolCallMessage adds a tool call message to display.
 func (m *Model) addToolCallMessage(name string, args string) {
 	m.messages = append(m.messages, DisplayMessage{
 		Role:    llm.Role("tool_call"),
@@ -286,7 +236,7 @@ func (m *Model) addToolCallMessage(name string, args string) {
 	})
 }
 
-// addToolResultMessage adds a tool result message to display
+// addToolResultMessage adds a tool result message to display.
 func (m *Model) addToolResultMessage(name string, result string) {
 	m.messages = append(m.messages, DisplayMessage{
 		Role:    llm.Role("tool_result"),
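The new registry loop treats an empty `EnvKey` as "no key required" (Ollama) and otherwise consults the environment. That readiness rule reduces to a one-line predicate; a sketch with an injected lookup so it can be tested without touching real env vars (names are illustrative, not the library's):

```go
package main

import (
	"fmt"
	"os"
)

// hasKey reports whether a provider is usable: key-less providers are
// always ready; keyed providers need their environment variable set.
func hasKey(envKey string, lookup func(string) string) bool {
	return envKey == "" || lookup(envKey) != ""
}

func main() {
	os.Setenv("DEMO_GROQ_KEY", "gsk-test")
	fmt.Println(hasKey("", os.Getenv))              // key-less, e.g. Ollama: true
	fmt.Println(hasKey("DEMO_GROQ_KEY", os.Getenv)) // set above: true
	fmt.Println(hasKey("DEMO_MISSING", os.Getenv))  // unset: false
}
```

Keeping the lookup as a parameter is what lets the TUI logic stay a pure function of the registry entry rather than of global process state.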
@@ -0,0 +1,114 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
)

// TimeParams is the parameter struct for the GetTime function.
type TimeParams struct{}

// GetTime returns the current time.
func GetTime(_ context.Context, _ TimeParams) (string, error) {
	return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil
}

// CalcParams is the parameter struct for the Calculate function.
type CalcParams struct {
	A  float64 `json:"a" description:"First number"`
	B  float64 `json:"b" description:"Second number"`
	Op string  `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"`
}

// Calculate performs basic math operations.
func Calculate(_ context.Context, params CalcParams) (string, error) {
	var result float64
	switch strings.ToLower(params.Op) {
	case "add", "+":
		result = params.A + params.B
	case "subtract", "sub", "-":
		result = params.A - params.B
	case "multiply", "mul", "*":
		result = params.A * params.B
	case "divide", "div", "/":
		if params.B == 0 {
			return "", fmt.Errorf("division by zero")
		}
		result = params.A / params.B
	case "power", "pow", "^":
		result = math.Pow(params.A, params.B)
	case "sqrt":
		if params.A < 0 {
			return "", fmt.Errorf("cannot take square root of negative number")
		}
		result = math.Sqrt(params.A)
	case "mod", "%":
		result = math.Mod(params.A, params.B)
	default:
		return "", fmt.Errorf("unknown operation: %s", params.Op)
	}
	return strconv.FormatFloat(result, 'f', -1, 64), nil
}

// WeatherParams is the parameter struct for the GetWeather function.
type WeatherParams struct {
	Location string `json:"location" description:"City name or location"`
}

// GetWeather returns mock weather data (for demo purposes).
func GetWeather(_ context.Context, params WeatherParams) (string, error) {
	weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"}
	temps := []int{65, 72, 58, 80, 45}
	idx := len(params.Location) % len(weathers)

	out := map[string]any{
		"location":    params.Location,
		"temperature": strconv.Itoa(temps[idx]) + "F",
		"condition":   weathers[idx],
		"humidity":    "45%",
		"note":        "This is mock data for demonstration purposes",
	}
	b, err := json.Marshal(out)
	if err != nil {
		return "", err
	}
	return string(b), nil
}

// RandomNumberParams is the parameter struct for the RandomNumber function.
type RandomNumberParams struct {
	Min int `json:"min" description:"Minimum value (inclusive)"`
	Max int `json:"max" description:"Maximum value (inclusive)"`
}

// RandomNumber generates a pseudo-random number (using current time nanoseconds).
func RandomNumber(_ context.Context, params RandomNumberParams) (string, error) {
	if params.Min > params.Max {
		return "", fmt.Errorf("min cannot be greater than max")
	}
	n := time.Now().UnixNano()
	rangeSize := params.Max - params.Min + 1
	result := params.Min + int(n%int64(rangeSize))
	return strconv.Itoa(result), nil
}

// createDemoToolbox creates a toolbox with demo tools for testing.
func createDemoToolbox() *llm.ToolBox {
	return llm.NewToolBox(
		llm.Define[TimeParams]("get_time", "Get the current date and time", GetTime),
		llm.Define[CalcParams]("calculate",
			"Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)",
			Calculate),
		llm.Define[WeatherParams]("get_weather",
			"Get weather information for a location (demo data)", GetWeather),
		llm.Define[RandomNumberParams]("random_number",
			"Generate a random number between min and max", RandomNumber),
	)
}
@@ -8,14 +8,14 @@ import (
 	"github.com/charmbracelet/bubbles/viewport"
 	tea "github.com/charmbracelet/bubbletea"

-	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
+	llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
 )

-// pendingRequest stores the request being processed for follow-up
-var pendingRequest llm.Request
-var pendingResponse llm.ResponseChoice
+// pendingToolCalls stores the last response's tool calls so we can pair them
+// with tool execution results for display.
+var pendingToolCalls []llm.ToolCall

-// Update handles messages and updates the model
+// Update handles messages and updates the model.
 func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	var cmd tea.Cmd
 	var cmds []tea.Cmd
@@ -53,40 +53,30 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			return m, nil
 		}

-		if len(msg.Response.Choices) == 0 {
-			m.err = fmt.Errorf("no response choices returned")
-			return m, nil
+		resp := msg.Response
+
+		// Add the assistant message to the conversation history.
+		m.conversation = append(m.conversation, resp.Message())
+
+		// Show any text the assistant produced alongside tool calls.
+		if resp.Text != "" {
+			m.addAssistantMessage(resp.Text)
 		}

-		choice := msg.Response.Choices[0]
-
-		// Check for tool calls
-		if len(choice.Calls) > 0 && m.toolsEnabled {
-			// Store for follow-up
-			pendingResponse = choice
-
-			// Add assistant's response to conversation if there's content
-			if choice.Content != "" {
-				m.addAssistantMessage(choice.Content)
-			}
-
-			// Display tool calls
-			for _, call := range choice.Calls {
-				m.addToolCallMessage(call.FunctionCall.Name, call.FunctionCall.Arguments)
+		if resp.HasToolCalls() && m.toolsEnabled {
+			pendingToolCalls = resp.ToolCalls
+
+			for _, call := range resp.ToolCalls {
+				m.addToolCallMessage(call.Name, call.Arguments)
 			}

 			m.viewport.SetContent(m.renderMessages())
 			m.viewport.GotoBottom()

-			// Execute tools
 			m.loading = true
-			return m, executeTools(m.toolbox, pendingRequest, choice)
+			return m, executeTools(m.toolbox, resp.ToolCalls)
 		}

-		// Regular response - add to conversation and display
-		m.conversation = append(m.conversation, choice)
-		m.addAssistantMessage(choice.Content)
-
 		m.viewport.SetContent(m.renderMessages())
 		m.viewport.GotoBottom()

@@ -97,31 +87,24 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			return m, nil
 		}

-		// Display tool results
+		// Display results paired with the tool calls that produced them.
 		for i, result := range msg.Results {
-			name := pendingResponse.Calls[i].FunctionCall.Name
-			resultStr := fmt.Sprintf("%v", result.Result)
-			if result.Error != nil {
-				resultStr = "Error: " + result.Error.Error()
+			name := ""
+			if i < len(pendingToolCalls) {
+				name = pendingToolCalls[i].Name
 			}
-			m.addToolResultMessage(name, resultStr)
+			m.addToolResultMessage(name, result.Content.Text)
 		}

-		// Add tool call responses to conversation
-		for _, result := range msg.Results {
-			m.conversation = append(m.conversation, result)
-		}
-
-		// Add the assistant's response to conversation
-		m.conversation = append(m.conversation, pendingResponse)
+		// Append the raw tool result messages to the conversation so the
+		// assistant can reference them on the next turn.
+		m.conversation = append(m.conversation, msg.Results...)

 		m.viewport.SetContent(m.renderMessages())
 		m.viewport.GotoBottom()

-		// Send follow-up request
-		followUp := buildFollowUpRequest(&m, pendingRequest, pendingResponse, msg.Results)
-		pendingRequest = followUp
-		return m, sendChatRequest(m.chat, followUp)
+		// Ask the model to continue given the tool results.
+		return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature)

 	case ImageLoadedMsg:
 		if msg.Err != nil {
@@ -135,7 +118,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		m.err = nil

 	default:
-		// Update text input
+		// Update text input.
 		if m.state == StateChat {
 			m.input, cmd = m.input.Update(msg)
 			cmds = append(cmds, cmd)
@@ -148,13 +131,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	return m, tea.Batch(cmds...)
 }

-// handleKeyMsg handles keyboard input
+// handleKeyMsg handles keyboard input.
 func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
-	// Global key handling
 	switch msg.String() {
 	case "ctrl+c":
 		return m, tea.Quit

 	case "esc":
 		if m.state != StateChat {
 			m.state = StateChat
@@ -164,7 +145,6 @@ func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 		return m, tea.Quit
 	}

-	// State-specific key handling
 	switch m.state {
 	case StateChat:
 		return m.handleChatKeys(msg)
@@ -185,7 +165,7 @@ func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	return m, nil
 }

-// handleChatKeys handles keys in chat state
+// handleChatKeys handles keys in chat state.
 func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "enter":
@@ -203,14 +183,13 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 			return m, nil
 		}

-		// Build and send request
-		req := buildRequest(&m, text)
-		pendingRequest = req
+		// Ensure a system message is at the head of the conversation.
+		if len(m.conversation) == 0 && m.systemPrompt != "" {
+			m.conversation = append(m.conversation, llm.SystemMessage(m.systemPrompt))
+		}

-		// Add user message to display
 		m.addUserMessage(text, m.pendingImages)

-		// Clear input and pending images
 		m.input.Reset()
 		m.pendingImages = nil
 		m.err = nil
@@ -219,7 +198,7 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 		m.viewport.SetContent(m.renderMessages())
 		m.viewport.GotoBottom()

-		return m, sendChatRequest(m.chat, req)
+		return m, sendChatRequest(m.chat, m.conversation, m.toolbox, m.toolsEnabled, m.temperature)

 	case "ctrl+i":
 		m.previousState = StateChat
@@ -238,12 +217,12 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 		return m, nil

 	case "ctrl+m":
-		if m.provider == nil {
+		if m.client == nil {
 			m.err = fmt.Errorf("select a provider first")
 			return m, nil
 		}
 		m.state = StateModelSelect
-		m.listItems = m.providers[m.providerIndex].Models
+		m.listItems = m.providers[m.providerIndex].Info.Models
 		m.listIndex = m.providers[m.providerIndex].ModelIndex
 		return m, nil

@@ -268,7 +247,7 @@ func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 		}
 	}
 }

-// handleProviderSelectKeys handles keys in provider selection state
+// handleProviderSelectKeys handles keys in provider selection state.
 func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "up", "k":
@@ -282,15 +261,13 @@ func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	case "enter":
 		p := m.providers[m.listIndex]
 		if !p.HasAPIKey {
-			// Need to get API key
 			m.state = StateAPIKeyInput
 			m.apiKeyInput.Focus()
 			m.apiKeyInput.SetValue("")
 			return m, textinput.Blink
 		}

-		err := m.selectProvider(m.listIndex)
-		if err != nil {
+		if err := m.selectProvider(m.listIndex); err != nil {
 			m.err = err
 			return m, nil
 		}
@@ -303,7 +280,7 @@ func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	return m, nil
 }

-// handleAPIKeyInputKeys handles keys in API key input state
+// handleAPIKeyInputKeys handles keys in API key input state.
 func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "enter":
@@ -312,23 +289,22 @@ func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 			return m, nil
 		}

-		// Store the API key
 		p := m.providers[m.listIndex]
-		m.apiKeys[p.Name] = key
+		m.apiKeys[p.Info.Name] = key
 		m.providers[m.listIndex].HasAPIKey = true

-		// Update list items
 		for i, prov := range m.providers {
 			status := " (no key)"
 			if prov.HasAPIKey {
 				status = " (ready)"
+				if prov.Info.EnvKey == "" {
+					status = " (local)"
+				}
 			}
-			m.listItems[i] = prov.Name + status
+			m.listItems[i] = prov.Info.DisplayName + status
 		}

-		// Select the provider
-		err := m.selectProvider(m.listIndex)
-		if err != nil {
+		if err := m.selectProvider(m.listIndex); err != nil {
 			m.err = err
 			return m, nil
 		}
@@ -345,7 +321,7 @@ func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	}
 }

-// handleModelSelectKeys handles keys in model selection state
+// handleModelSelectKeys handles keys in model selection state.
 func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "up", "k":
@@ -357,8 +333,7 @@ func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 			m.listIndex++
 		}
 	case "enter":
-		err := m.selectModel(m.listIndex)
-		if err != nil {
+		if err := m.selectModel(m.listIndex); err != nil {
 			m.err = err
 			return m, nil
 		}
@@ -368,7 +343,7 @@ func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	return m, nil
 }

-// handleImageInputKeys handles keys in image input state
+// handleImageInputKeys handles keys in image input state.
 func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "enter":
@@ -381,12 +356,12 @@ func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {

 		m.input.Placeholder = "Type your message..."

-		// Determine input type and load
-		if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
+		switch {
+		case strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://"):
 			return m, loadImageFromURL(input)
-		} else if strings.HasPrefix(input, "data:") || len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\") {
+		case strings.HasPrefix(input, "data:") || (len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\")):
 			return m, loadImageFromBase64(input)
-		} else {
+		default:
 			return m, loadImageFromPath(input)
 		}

@@ -397,7 +372,7 @@ func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	}
 }

-// handleToolsPanelKeys handles keys in tools panel state
+// handleToolsPanelKeys handles keys in tools panel state.
 func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "t":
@@ -409,11 +384,10 @@ func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	return m, nil
 }

-// handleSettingsKeys handles keys in settings state
+// handleSettingsKeys handles keys in settings state.
 func (m Model) handleSettingsKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
 	switch msg.String() {
 	case "1":
-		// Set temperature to nil (default)
 		m.temperature = nil
 	case "2":
 		t := 0.0
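The `Update` handler above pairs each tool result with the call that produced it purely by slice index, guarding with `i < len(pendingToolCalls)` so a mismatched length degrades to an unnamed result instead of panicking. A self-contained sketch of that pairing (the `toolCall`/`toolResult` structs here are illustrative stand-ins, not the library's `llm.ToolCall` types):

```go
package main

import "fmt"

// Minimal stand-ins for the TUI's call/result types.
type toolCall struct{ Name string }
type toolResult struct{ Text string }

// pairByIndex labels each result with the call at the same index; results
// beyond the call list get an empty name, mirroring the guard in Update.
func pairByIndex(calls []toolCall, results []toolResult) []string {
	out := make([]string, 0, len(results))
	for i, r := range results {
		name := ""
		if i < len(calls) {
			name = calls[i].Name
		}
		out = append(out, fmt.Sprintf("%s: %s", name, r.Text))
	}
	return out
}

func main() {
	calls := []toolCall{{Name: "get_time"}}
	results := []toolResult{{Text: "12:00"}, {Text: "orphan"}}
	fmt.Println(pairByIndex(calls, results))
}
```

Index pairing relies on the provider returning results in call order; the bounds check is what keeps a short `pendingToolCalls` slice from crashing the UI.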
@@ -6,10 +6,10 @@ import (

 	"github.com/charmbracelet/lipgloss"

-	llm "gitea.stevedudenhoeffer.com/steve/go-llm"
+	llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
 )

-// View renders the current state
+// View renders the current state.
 func (m Model) View() string {
 	switch m.state {
 	case StateProviderSelect:
@@ -29,11 +29,10 @@ func (m Model) View() string {
 	}
 }

-// renderChat renders the main chat view
+// renderChat renders the main chat view.
 func (m Model) renderChat() string {
 	var b strings.Builder

-	// Header
 	provider := m.providerName
 	if provider == "" {
 		provider = "None"
@@ -49,43 +48,37 @@ func (m Model) renderChat() string {
 	b.WriteString(header)
 	b.WriteString("\n")

-	// Messages viewport
 	if m.viewportReady {
 		b.WriteString(m.viewport.View())
 		b.WriteString("\n")
 	}

-	// Image indicator
 	if len(m.pendingImages) > 0 {
 		b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s) attached]", len(m.pendingImages))))
 		b.WriteString("\n")
 	}

-	// Error
 	if m.err != nil {
 		b.WriteString(errorStyle.Render(" Error: " + m.err.Error()))
 		b.WriteString("\n")
 	}

-	// Loading
 	if m.loading {
 		b.WriteString(loadingStyle.Render(" Thinking..."))
 		b.WriteString("\n")
 	}

-	// Input
 	inputBox := inputStyle.Render(m.input.View())
 	b.WriteString(inputBox)
 	b.WriteString("\n")

-	// Help
 	help := inputHelpStyle.Render("Enter: send | Ctrl+I: image | Ctrl+T: tools | Ctrl+P: provider | Ctrl+M: model | Ctrl+S: settings | Ctrl+N: new | Esc: quit")
 	b.WriteString(help)

 	return appStyle.Render(b.String())
 }

-// renderMessages renders all messages for the viewport
+// renderMessages renders all messages for the viewport.
 func (m Model) renderMessages() string {
 	var b strings.Builder

@@ -133,7 +126,7 @@ func (m Model) renderMessages() string {
 	return b.String()
 }

-// renderProviderSelect renders the provider selection view
+// renderProviderSelect renders the provider selection view.
 func (m Model) renderProviderSelect() string {
 	var b strings.Builder

@@ -157,16 +150,18 @@ func (m Model) renderProviderSelect() string {
 	return appStyle.Render(b.String())
 }

-// renderAPIKeyInput renders the API key input view
+// renderAPIKeyInput renders the API key input view.
 func (m Model) renderAPIKeyInput() string {
 	var b strings.Builder

 	provider := m.providers[m.listIndex]

-	b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Name)))
+	b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Info.DisplayName)))
 	b.WriteString("\n\n")

-	b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.EnvVar))
+	if provider.Info.EnvKey != "" {
+		b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.Info.EnvKey))
+	}
 	b.WriteString("Enter your API key below (it will be hidden):\n\n")

 	inputBox := inputStyle.Render(m.apiKeyInput.View())
@@ -178,7 +173,7 @@ func (m Model) renderAPIKeyInput() string {
 	return appStyle.Render(b.String())
 }

-// renderModelSelect renders the model selection view
+// renderModelSelect renders the model selection view.
 func (m Model) renderModelSelect() string {
 	var b strings.Builder

@@ -205,7 +200,7 @@ func (m Model) renderModelSelect() string {
 	return appStyle.Render(b.String())
 }

-// renderImageInput renders the image input view
+// renderImageInput renders the image input view.
 func (m Model) renderImageInput() string {
 	var b strings.Builder

@@ -230,7 +225,7 @@ func (m Model) renderImageInput() string {
 	return appStyle.Render(b.String())
 }

-// renderToolsPanel renders the tools panel
+// renderToolsPanel renders the tools panel.
 func (m Model) renderToolsPanel() string {
 	var b strings.Builder

@@ -249,8 +244,10 @@ func (m Model) renderToolsPanel() string {
 	b.WriteString("\n\n")

 	b.WriteString("Available tools:\n")
-	for _, fn := range m.toolbox.Functions() {
-		b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(fn.Name), fn.Description))
+	if m.toolbox != nil {
+		for _, t := range m.toolbox.AllTools() {
+			b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(t.Name), t.Description))
+		}
 	}

 	b.WriteString("\n")
@@ -259,14 +256,13 @@ func (m Model) renderToolsPanel() string {
 	return appStyle.Render(b.String())
 }

-// renderSettings renders the settings view
+// renderSettings renders the settings view.
 func (m Model) renderSettings() string {
 	var b strings.Builder

 	b.WriteString(headerStyle.Render("Settings"))
 	b.WriteString("\n\n")

-	// Temperature
 	tempStr := "default"
 	if m.temperature != nil {
 		tempStr = fmt.Sprintf("%.1f", *m.temperature)
@@ -284,7 +280,6 @@ func (m Model) renderSettings() string {

 	b.WriteString("\n")

-	// System prompt
 	b.WriteString(settingLabelStyle.Render("System Prompt:"))
 	b.WriteString("\n")
 	b.WriteString(settingValueStyle.Render(" " + m.systemPrompt))
@@ -2,8 +2,13 @@ package llm

 import (
 	anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic"
+	deepseekProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
 	googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google"
+	groqProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
+	moonshotProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
+	ollamaProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
 	openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
+	xaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
 )

 // OpenAI creates an OpenAI client.
@@ -46,3 +51,69 @@ func Google(apiKey string, opts ...ClientOption) *Client {
 	_ = cfg // Google doesn't support custom base URL in the SDK
 	return NewClient(googleProvider.New(apiKey))
 }
+
+// DeepSeek creates a DeepSeek client (OpenAI-compatible).
+//
+// Example:
+//
+//	model := llm.DeepSeek("sk-...").Model("deepseek-chat")
+func DeepSeek(apiKey string, opts ...ClientOption) *Client {
+	cfg := &clientConfig{}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	return NewClient(deepseekProvider.New(apiKey, cfg.baseURL))
+}
+
+// Moonshot creates a Moonshot AI (Kimi) client (OpenAI-compatible).
+//
+// Example:
+//
+//	model := llm.Moonshot("sk-...").Model("kimi-k2-0711-preview")
+func Moonshot(apiKey string, opts ...ClientOption) *Client {
+	cfg := &clientConfig{}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	return NewClient(moonshotProvider.New(apiKey, cfg.baseURL))
+}
+
+// XAI creates an xAI (Grok) client (OpenAI-compatible).
+//
+// Example:
+//
+//	model := llm.XAI("xai-...").Model("grok-2")
+func XAI(apiKey string, opts ...ClientOption) *Client {
+	cfg := &clientConfig{}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	return NewClient(xaiProvider.New(apiKey, cfg.baseURL))
+}
+
+// Groq creates a Groq client (OpenAI-compatible).
+//
+// Example:
+//
+//	model := llm.Groq("gsk-...").Model("llama-3.3-70b-versatile")
+func Groq(apiKey string, opts ...ClientOption) *Client {
+	cfg := &clientConfig{}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	return NewClient(groqProvider.New(apiKey, cfg.baseURL))
+}
+
+// Ollama creates a client for a local Ollama instance (OpenAI-compatible).
+// No API key is required. Use WithBaseURL to point at a non-default host/port.
+//
+// Example:
+//
+//	model := llm.Ollama().Model("llama3.2")
+func Ollama(opts ...ClientOption) *Client {
+	cfg := &clientConfig{}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	return NewClient(ollamaProvider.New("", cfg.baseURL))
+}
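Every constructor in the hunk above shares the same functional-options shape: build a `clientConfig`, apply each `ClientOption`, then fall back to the provider's default endpoint when no base URL was supplied. A self-contained sketch of that pattern (the `clientConfig`/`ClientOption`/`WithBaseURL` names mirror the diff, but `newProviderURL` is our illustrative stand-in for a constructor body):

```go
package main

import "fmt"

// clientConfig collects optional settings; options mutate it in place.
type clientConfig struct{ baseURL string }

// ClientOption is the functional-option type the constructors accept.
type ClientOption func(*clientConfig)

// WithBaseURL overrides the provider's default endpoint.
func WithBaseURL(u string) ClientOption {
	return func(c *clientConfig) { c.baseURL = u }
}

// newProviderURL applies the options and resolves the effective endpoint,
// the same apply-then-default dance each constructor performs.
func newProviderURL(def string, opts ...ClientOption) string {
	cfg := &clientConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	if cfg.baseURL == "" {
		return def
	}
	return cfg.baseURL
}

func main() {
	fmt.Println(newProviderURL("https://api.deepseek.com/v1"))
	fmt.Println(newProviderURL("https://api.deepseek.com/v1", WithBaseURL("http://localhost:11434/v1")))
}
```

The pattern keeps the constructors' signatures stable as new options are added, which matters here since five providers share it.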
@@ -0,0 +1,36 @@
+// Package deepseek implements the go-llm v2 provider interface for DeepSeek
+// (https://platform.deepseek.com). DeepSeek speaks the OpenAI Chat Completions
+// protocol, so this package is a thin wrapper around openaicompat with its own
+// defaults and per-model Rules.
+package deepseek
+
+import (
+	"strings"
+
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
+)
+
+// DefaultBaseURL is the public DeepSeek API endpoint.
+const DefaultBaseURL = "https://api.deepseek.com/v1"
+
+// Provider is a type alias over openaicompat.Provider.
+type Provider = openaicompat.Provider
+
+// New creates a new DeepSeek provider. An empty baseURL uses DefaultBaseURL.
+func New(apiKey, baseURL string) *Provider {
+	if baseURL == "" {
+		baseURL = DefaultBaseURL
+	}
+	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
+		// DeepSeek's chat and reasoner models are text-only.
+		SupportsVision: func(string) bool { return false },
+		// Reasoner doesn't accept tool calls.
+		SupportsTools: func(m string) bool {
+			return !strings.Contains(m, "reasoner")
+		},
+		// Reasoner rejects user-supplied temperature.
+		RestrictTemperature: func(m string) bool {
+			return strings.Contains(m, "reasoner")
+		},
+	})
+}
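The package doc above describes the whole design: each provider declares its constraints as plain predicates over the model name, and the shared openaicompat layer enforces them. The predicates are trivially testable in isolation; a sketch using a local `rules` struct that mirrors the shape of `openaicompat.Rules` (it is an illustrative copy, not the library type):

```go
package main

import (
	"fmt"
	"strings"
)

// rules mirrors the predicate-per-feature shape of openaicompat.Rules.
type rules struct {
	supportsTools       func(string) bool
	restrictTemperature func(string) bool
}

// deepseekRules reproduces the two predicates from the New function above:
// any model containing "reasoner" loses tools and temperature control.
var deepseekRules = rules{
	supportsTools:       func(m string) bool { return !strings.Contains(m, "reasoner") },
	restrictTemperature: func(m string) bool { return strings.Contains(m, "reasoner") },
}

func main() {
	for _, m := range []string{"deepseek-chat", "deepseek-reasoner"} {
		fmt.Printf("%s tools=%v tempRestricted=%v\n",
			m, deepseekRules.supportsTools(m), deepseekRules.restrictTemperature(m))
	}
}
```

Substring matching keeps the rules robust to dated model suffixes, at the cost of being loose; a new model name containing "reasoner" is automatically restricted.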
@@ -0,0 +1,49 @@
+package deepseek_test
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
+)
+
+func TestNew_DefaultBaseURL(t *testing.T) {
+	if p := deepseek.New("key", ""); p == nil {
+		t.Fatal("New returned nil")
+	}
+}
+
+func TestRules_ReasonerRejectsTools(t *testing.T) {
+	p := deepseek.New("key", "")
+	req := provider.Request{
+		Model:    "deepseek-reasoner",
+		Messages: []provider.Message{{Role: "user", Content: "hi"}},
+		Tools: []provider.ToolDef{
+			{Name: "x", Schema: map[string]any{"type": "object"}},
+		},
+	}
+	_, err := p.Complete(context.Background(), req)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) || fue.Feature != "tools" {
+		t.Fatalf("want FeatureUnsupportedError(tools), got %v", err)
+	}
+}
+
+func TestRules_ChatRejectsImages(t *testing.T) {
+	p := deepseek.New("key", "")
+	req := provider.Request{
+		Model: "deepseek-chat",
+		Messages: []provider.Message{{
+			Role:   "user",
+			Images: []provider.Image{{URL: "a"}},
+		}},
+	}
+	_, err := p.Complete(context.Background(), req)
+	var fue *openaicompat.FeatureUnsupportedError
+	if !errors.As(err, &fue) || fue.Feature != "vision" {
+		t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
+	}
+}
@@ -1,10 +1,12 @@
 module gitea.stevedudenhoeffer.com/steve/go-llm/v2
 
-go 1.24.0
+go 1.24.2
 
-toolchain go1.24.2
-
 require (
+	github.com/charmbracelet/bubbles v1.0.0
+	github.com/charmbracelet/bubbletea v1.3.10
+	github.com/charmbracelet/lipgloss v1.1.0
+	github.com/joho/godotenv v1.5.1
 	github.com/liushuangls/go-anthropic/v2 v2.17.0
 	github.com/modelcontextprotocol/go-sdk v1.2.0
 	github.com/openai/openai-go v1.12.0
@@ -18,6 +20,16 @@ require (
 	cloud.google.com/go v0.116.0 // indirect
 	cloud.google.com/go/auth v0.9.3 // indirect
 	cloud.google.com/go/compute/metadata v0.5.0 // indirect
+	github.com/atotto/clipboard v0.1.4 // indirect
+	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
+	github.com/charmbracelet/colorprofile v0.4.1 // indirect
+	github.com/charmbracelet/x/ansi v0.11.6 // indirect
+	github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
+	github.com/charmbracelet/x/term v0.2.2 // indirect
+	github.com/clipperhouse/displaywidth v0.9.0 // indirect
+	github.com/clipperhouse/stringish v0.1.1 // indirect
+	github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
+	github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 	github.com/google/go-cmp v0.7.0 // indirect
 	github.com/google/jsonschema-go v0.3.0 // indirect
@@ -25,15 +37,24 @@ require (
 	github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
 	github.com/gorilla/websocket v1.5.3 // indirect
 	github.com/kr/fs v0.1.0 // indirect
+	github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
+	github.com/mattn/go-isatty v0.0.20 // indirect
+	github.com/mattn/go-localereader v0.0.1 // indirect
+	github.com/mattn/go-runewidth v0.0.19 // indirect
+	github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
+	github.com/muesli/cancelreader v0.2.2 // indirect
+	github.com/muesli/termenv v0.16.0 // indirect
+	github.com/rivo/uniseg v0.4.7 // indirect
 	github.com/tidwall/gjson v1.14.4 // indirect
 	github.com/tidwall/match v1.1.1 // indirect
 	github.com/tidwall/pretty v1.2.1 // indirect
 	github.com/tidwall/sjson v1.2.5 // indirect
+	github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
 	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
 	go.opencensus.io v0.24.0 // indirect
 	golang.org/x/net v0.42.0 // indirect
 	golang.org/x/oauth2 v0.30.0 // indirect
-	golang.org/x/sys v0.35.0 // indirect
+	golang.org/x/sys v0.38.0 // indirect
 	golang.org/x/text v0.33.0 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
 	google.golang.org/grpc v1.66.2 // indirect
@@ -6,8 +6,32 @@ cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842Bg
 cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
 cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
+github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
+github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
+github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
+github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
+github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
+github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
+github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
+github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
+github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
+github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
+github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
+github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
+github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
+github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
+github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
+github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
 github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
+github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
+github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
+github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
+github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
+github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
+github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
 github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -16,6 +40,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
 github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
+github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
 github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
 github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@@ -49,12 +75,28 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT
 github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
 github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
 github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
 github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
 github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
 github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
+github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
+github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
+github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
+github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
+github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
 github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
 github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
+github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
+github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
+github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
+github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
+github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
+github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
 github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
 github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
 github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
@@ -62,6 +104,8 @@ github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1Hbe
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
+github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -80,6 +124,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
 github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
 github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
+github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
+github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
 github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
 github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
 go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
@@ -89,6 +135,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
 golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
 golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
+golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
 golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
 golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -114,8 +162,10 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
-golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
+golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
 golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

@@ -0,0 +1,33 @@
// Package groq implements the go-llm v2 provider interface for Groq
// (https://console.groq.com). Groq hosts open-source models behind an OpenAI
// Chat Completions-compatible endpoint, so this package is a thin wrapper over
// openaicompat with its own defaults and per-model Rules.
package groq

import (
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)

// DefaultBaseURL is the public Groq OpenAI-compatible endpoint.
const DefaultBaseURL = "https://api.groq.com/openai/v1"

// Provider is a type alias over openaicompat.Provider.
type Provider = openaicompat.Provider

// New creates a new Groq provider. An empty baseURL uses DefaultBaseURL.
func New(apiKey, baseURL string) *Provider {
	if baseURL == "" {
		baseURL = DefaultBaseURL
	}
	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
		// Only Groq-hosted vision variants (e.g. *-vision-preview) accept images.
		SupportsVision: func(m string) bool {
			return strings.Contains(m, "vision")
		},
		// The chat completions endpoint does not accept audio input; audio goes
		// through dedicated transcription endpoints, which go-llm doesn't cover here.
		SupportsAudio: func(string) bool { return false },
	})
}
@@ -0,0 +1,33 @@
package groq_test

import (
	"context"
	"errors"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

func TestNew_Basic(t *testing.T) {
	if p := groq.New("key", ""); p == nil {
		t.Fatal("New returned nil")
	}
}

func TestRules_AudioRejected(t *testing.T) {
	p := groq.New("key", "")
	req := provider.Request{
		Model: "llama-3.3-70b-versatile",
		Messages: []provider.Message{{
			Role:  "user",
			Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
		}},
	}
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) || fue.Feature != "audio" {
		t.Fatalf("want FeatureUnsupportedError(audio), got %v", err)
	}
}
@@ -0,0 +1,30 @@
// Package moonshot implements the go-llm v2 provider interface for Moonshot
// AI (Kimi, https://platform.moonshot.ai). Moonshot speaks OpenAI Chat
// Completions, so this package is a thin wrapper over openaicompat with its
// own defaults and per-model Rules.
package moonshot

import (
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)

// DefaultBaseURL is the public Moonshot API endpoint (international).
const DefaultBaseURL = "https://api.moonshot.ai/v1"

// Provider is a type alias over openaicompat.Provider.
type Provider = openaicompat.Provider

// New creates a new Moonshot provider. An empty baseURL uses DefaultBaseURL.
func New(apiKey, baseURL string) *Provider {
	if baseURL == "" {
		baseURL = DefaultBaseURL
	}
	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
		// Only Moonshot models whose name contains "vision" accept images.
		SupportsVision: func(m string) bool {
			return strings.Contains(m, "vision")
		},
	})
}
@@ -0,0 +1,33 @@
package moonshot_test

import (
	"context"
	"errors"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

func TestNew_Basic(t *testing.T) {
	if p := moonshot.New("key", ""); p == nil {
		t.Fatal("New returned nil")
	}
}

func TestRules_NonVisionModelRejectsImages(t *testing.T) {
	p := moonshot.New("key", "")
	req := provider.Request{
		Model: "moonshot-v1-8k",
		Messages: []provider.Message{{
			Role:   "user",
			Images: []provider.Image{{URL: "a"}},
		}},
	}
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) || fue.Feature != "vision" {
		t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
	}
}
@@ -0,0 +1,25 @@
// Package ollama implements the go-llm v2 provider interface for Ollama
// (https://ollama.com), a local model runner that exposes an OpenAI Chat
// Completions-compatible endpoint. No API key is required; capability depends
// on whichever model the user has pulled locally, so Rules are intentionally
// empty; we trust the local user.
package ollama

import (
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)

// DefaultBaseURL points at a local Ollama instance with default port.
const DefaultBaseURL = "http://localhost:11434/v1"

// Provider is a type alias over openaicompat.Provider.
type Provider = openaicompat.Provider

// New creates a new Ollama provider. An empty baseURL uses DefaultBaseURL.
// Ollama ignores the API key; callers may pass "".
func New(apiKey, baseURL string) *Provider {
	if baseURL == "" {
		baseURL = DefaultBaseURL
	}
	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{})
}
@@ -0,0 +1,13 @@
package ollama_test

import (
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
)

func TestNew_NoKeyNeeded(t *testing.T) {
	if p := ollama.New("", ""); p == nil {
		t.Fatal("New returned nil")
	}
}
+20 -418
@@ -1,433 +1,35 @@
 // Package openai implements the go-llm v2 provider interface for OpenAI.
+//
+// The actual wire-protocol logic lives in the shared openaicompat package;
+// this file encodes OpenAI-specific Rules (temperature is rejected on o-series
+// and gpt-5* models) and supplies the default base URL.
 package openai
 
 import (
-	"context"
-	"encoding/base64"
-	"fmt"
-	"io"
-	"net/http"
-	"path"
 	"strings"
 
-	"github.com/openai/openai-go"
-	"github.com/openai/openai-go/option"
-	"github.com/openai/openai-go/packages/param"
-	"github.com/openai/openai-go/shared"
-
-	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
+	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
 )
 
-// Provider implements the provider.Provider interface for OpenAI.
-type Provider struct {
-	apiKey  string
-	baseURL string
-}
+// DefaultBaseURL is the public OpenAI Chat Completions endpoint.
+const DefaultBaseURL = "https://api.openai.com/v1"
 
-// New creates a new OpenAI provider.
+// Provider is the OpenAI chat-completion provider. It's a type alias over
+// openaicompat.Provider so existing callers using openai.Provider keep compiling.
+type Provider = openaicompat.Provider
+
+// New creates a new OpenAI provider. An empty baseURL uses DefaultBaseURL.
 func New(apiKey string, baseURL string) *Provider {
-	return &Provider{apiKey: apiKey, baseURL: baseURL}
+	if baseURL == "" {
+		baseURL = DefaultBaseURL
+	}
+	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
+		RestrictTemperature: restrictTemperature,
+	})
 }
-
-// Complete performs a non-streaming completion.
-func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
-	var opts []option.RequestOption
-	opts = append(opts, option.WithAPIKey(p.apiKey))
-	if p.baseURL != "" {
-		opts = append(opts, option.WithBaseURL(p.baseURL))
-	}
-
-	cl := openai.NewClient(opts...)
-	oaiReq := p.buildRequest(req)
-
-	resp, err := cl.Chat.Completions.New(ctx, oaiReq)
-	if err != nil {
-		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
-	}
-
-	return p.convertResponse(resp), nil
-}
-
-// Stream performs a streaming completion.
-func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
-	var opts []option.RequestOption
-	opts = append(opts, option.WithAPIKey(p.apiKey))
-	if p.baseURL != "" {
-		opts = append(opts, option.WithBaseURL(p.baseURL))
-	}
-
-	cl := openai.NewClient(opts...)
-	oaiReq := p.buildRequest(req)
-	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
-		IncludeUsage: openai.Bool(true),
-	}
-
-	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
-
-	var fullText strings.Builder
-	var toolCalls []provider.ToolCall
-	toolCallArgs := map[int]*strings.Builder{}
-	var usage *provider.Usage
-
-	for stream.Next() {
-		chunk := stream.Current()
-
-		// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
-		if chunk.Usage.TotalTokens > 0 {
-			usage = &provider.Usage{
-				InputTokens:  int(chunk.Usage.PromptTokens),
-				OutputTokens: int(chunk.Usage.CompletionTokens),
-				TotalTokens:  int(chunk.Usage.TotalTokens),
-				Details:      extractUsageDetails(chunk.Usage),
-			}
-		}
-
-		for _, choice := range chunk.Choices {
-			// Text delta
-			if choice.Delta.Content != "" {
-				fullText.WriteString(choice.Delta.Content)
-				events <- provider.StreamEvent{
-					Type: provider.StreamEventText,
-					Text: choice.Delta.Content,
-				}
-			}
-
-			// Tool call deltas
-			for _, tc := range choice.Delta.ToolCalls {
-				idx := int(tc.Index)
-
-				if tc.ID != "" {
-					// New tool call starting
-					for len(toolCalls) <= idx {
-						toolCalls = append(toolCalls, provider.ToolCall{})
-					}
-					toolCalls[idx].ID = tc.ID
-					toolCalls[idx].Name = tc.Function.Name
-					toolCallArgs[idx] = &strings.Builder{}
-
-					events <- provider.StreamEvent{
-						Type: provider.StreamEventToolStart,
-						ToolCall: &provider.ToolCall{
-							ID:   tc.ID,
-							Name: tc.Function.Name,
-						},
-						ToolIndex: idx,
-					}
-				}
-
-				if tc.Function.Arguments != "" {
-					if b, ok := toolCallArgs[idx]; ok {
-						b.WriteString(tc.Function.Arguments)
-					}
-					events <- provider.StreamEvent{
-						Type:      provider.StreamEventToolDelta,
-						ToolIndex: idx,
-						ToolCall: &provider.ToolCall{
-							Arguments: tc.Function.Arguments,
-						},
-					}
-				}
-			}
-		}
-	}
-
-	if err := stream.Err(); err != nil {
-		return fmt.Errorf("openai stream error: %w", err)
-	}
-
-	// Finalize tool calls
-	for idx := range toolCalls {
-		if b, ok := toolCallArgs[idx]; ok {
-			toolCalls[idx].Arguments = b.String()
-		}
-		events <- provider.StreamEvent{
-			Type:      provider.StreamEventToolEnd,
-			ToolIndex: idx,
-			ToolCall:  &toolCalls[idx],
-		}
-	}
-
-	// Send done event
-	events <- provider.StreamEvent{
-		Type: provider.StreamEventDone,
-		Response: &provider.Response{
-			Text:      fullText.String(),
-			ToolCalls: toolCalls,
-			Usage:     usage,
-		},
-	}
-
-	return nil
-}
-
-func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
-	oaiReq := openai.ChatCompletionNewParams{
-		Model: req.Model,
-	}
-
-	for _, msg := range req.Messages {
-		oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
-	}
-
-	for _, tool := range req.Tools {
-		oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
-			Type: "function",
-			Function: shared.FunctionDefinitionParam{
-				Name:        tool.Name,
-				Description: openai.String(tool.Description),
-				Parameters:  openai.FunctionParameters(tool.Schema),
-			},
-		})
-	}
-
-	if req.Temperature != nil {
-		// o* and gpt-5* models don't support custom temperatures
-		if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") {
|
|
||||||
oaiReq.Temperature = openai.Float(*req.Temperature)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.MaxTokens != nil {
|
|
||||||
oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.TopP != nil {
|
|
||||||
oaiReq.TopP = openai.Float(*req.TopP)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.Stop) > 0 {
|
|
||||||
oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
|
|
||||||
}
|
|
||||||
|
|
||||||
return oaiReq
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
|
// restrictTemperature reports whether OpenAI rejects a user-supplied
|
||||||
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
|
// temperature for this model. o-series reasoning models and gpt-5* both do.
|
||||||
var textContent param.Opt[string]
|
func restrictTemperature(model string) bool {
|
||||||
|
return strings.HasPrefix(model, "o") || strings.HasPrefix(model, "gpt-5")
|
||||||
for _, img := range msg.Images {
|
|
||||||
var url string
|
|
||||||
if img.Base64 != "" {
|
|
||||||
url = "data:" + img.ContentType + ";base64," + img.Base64
|
|
||||||
} else if img.URL != "" {
|
|
||||||
url = img.URL
|
|
||||||
}
|
|
||||||
if url != "" {
|
|
||||||
arrayOfContentParts = append(arrayOfContentParts,
|
|
||||||
openai.ChatCompletionContentPartUnionParam{
|
|
||||||
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
|
||||||
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
|
||||||
URL: url,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, aud := range msg.Audio {
|
|
||||||
var b64Data string
|
|
||||||
var format string
|
|
||||||
|
|
||||||
if aud.Base64 != "" {
|
|
||||||
b64Data = aud.Base64
|
|
||||||
format = audioFormat(aud.ContentType)
|
|
||||||
} else if aud.URL != "" {
|
|
||||||
resp, err := http.Get(aud.URL)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b64Data = base64.StdEncoding.EncodeToString(data)
|
|
||||||
ct := resp.Header.Get("Content-Type")
|
|
||||||
if ct == "" {
|
|
||||||
ct = aud.ContentType
|
|
||||||
}
|
|
||||||
if ct == "" {
|
|
||||||
ct = audioFormatFromURL(aud.URL)
|
|
||||||
}
|
|
||||||
format = audioFormat(ct)
|
|
||||||
}
|
|
||||||
|
|
||||||
if b64Data != "" && format != "" {
|
|
||||||
arrayOfContentParts = append(arrayOfContentParts,
|
|
||||||
openai.ChatCompletionContentPartUnionParam{
|
|
||||||
OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
|
|
||||||
InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
|
|
||||||
Data: b64Data,
|
|
||||||
Format: format,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.Content != "" {
|
|
||||||
if len(arrayOfContentParts) > 0 {
|
|
||||||
arrayOfContentParts = append(arrayOfContentParts,
|
|
||||||
openai.ChatCompletionContentPartUnionParam{
|
|
||||||
OfText: &openai.ChatCompletionContentPartTextParam{
|
|
||||||
Text: msg.Content,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
textContent = openai.String(msg.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine if this model uses developer messages instead of system
|
|
||||||
useDeveloper := false
|
|
||||||
parts := strings.Split(model, "-")
|
|
||||||
if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
|
|
||||||
useDeveloper = true
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg.Role {
|
|
||||||
case "system":
|
|
||||||
if useDeveloper {
|
|
||||||
return openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
|
|
||||||
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfSystem: &openai.ChatCompletionSystemMessageParam{
|
|
||||||
Content: openai.ChatCompletionSystemMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
case "user":
|
|
||||||
return openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfUser: &openai.ChatCompletionUserMessageParam{
|
|
||||||
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
OfArrayOfContentParts: arrayOfContentParts,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
case "assistant":
|
|
||||||
as := &openai.ChatCompletionAssistantMessageParam{}
|
|
||||||
if msg.Content != "" {
|
|
||||||
as.Content.OfString = openai.String(msg.Content)
|
|
||||||
}
|
|
||||||
for _, tc := range msg.ToolCalls {
|
|
||||||
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
|
|
||||||
ID: tc.ID,
|
|
||||||
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
|
||||||
Name: tc.Name,
|
|
||||||
Arguments: tc.Arguments,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return openai.ChatCompletionMessageParamUnion{OfAssistant: as}
|
|
||||||
|
|
||||||
case "tool":
|
|
||||||
return openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfTool: &openai.ChatCompletionToolMessageParam{
|
|
||||||
ToolCallID: msg.ToolCallID,
|
|
||||||
Content: openai.ChatCompletionToolMessageParamContentUnion{
|
|
||||||
OfString: openai.String(msg.Content),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to user message
|
|
||||||
return openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfUser: &openai.ChatCompletionUserMessageParam{
|
|
||||||
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
|
|
||||||
var res provider.Response
|
|
||||||
|
|
||||||
if resp == nil || len(resp.Choices) == 0 {
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
choice := resp.Choices[0]
|
|
||||||
res.Text = choice.Message.Content
|
|
||||||
|
|
||||||
for _, tc := range choice.Message.ToolCalls {
|
|
||||||
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
|
|
||||||
ID: tc.ID,
|
|
||||||
Name: tc.Function.Name,
|
|
||||||
Arguments: strings.TrimSpace(tc.Function.Arguments),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Usage.TotalTokens > 0 {
|
|
||||||
res.Usage = &provider.Usage{
|
|
||||||
InputTokens: int(resp.Usage.PromptTokens),
|
|
||||||
OutputTokens: int(resp.Usage.CompletionTokens),
|
|
||||||
TotalTokens: int(resp.Usage.TotalTokens),
|
|
||||||
}
|
|
||||||
res.Usage.Details = extractUsageDetails(resp.Usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3").
|
|
||||||
func audioFormat(contentType string) string {
|
|
||||||
ct := strings.ToLower(contentType)
|
|
||||||
switch {
|
|
||||||
case strings.Contains(ct, "wav"):
|
|
||||||
return "wav"
|
|
||||||
case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"):
|
|
||||||
return "mp3"
|
|
||||||
default:
|
|
||||||
return "wav"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
|
|
||||||
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
|
|
||||||
details := map[string]int{}
|
|
||||||
if usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
|
||||||
details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
|
|
||||||
}
|
|
||||||
if usage.CompletionTokensDetails.AudioTokens > 0 {
|
|
||||||
details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
|
|
||||||
}
|
|
||||||
if usage.PromptTokensDetails.CachedTokens > 0 {
|
|
||||||
details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
|
|
||||||
}
|
|
||||||
if usage.PromptTokensDetails.AudioTokens > 0 {
|
|
||||||
details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
|
|
||||||
}
|
|
||||||
if len(details) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return details
|
|
||||||
}
|
|
||||||
|
|
||||||
// audioFormatFromURL guesses the audio format from a URL's file extension.
|
|
||||||
func audioFormatFromURL(u string) string {
|
|
||||||
ext := strings.ToLower(path.Ext(u))
|
|
||||||
switch ext {
|
|
||||||
case ".mp3":
|
|
||||||
return "audio/mp3"
|
|
||||||
case ".wav":
|
|
||||||
return "audio/wav"
|
|
||||||
default:
|
|
||||||
return "audio/wav"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,537 @@
// Package openaicompat implements a shared chat-completion Provider for any
// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek,
// Moonshot, xAI, Groq, Ollama, and friends).
//
// Most providers differ from vanilla OpenAI only in endpoint URL and a handful
// of per-model quirks (e.g., "this model is text-only", "this model doesn't
// accept tools", "drop temperature on reasoning models"). Those quirks are
// captured declaratively via Rules, so a concrete provider package is usually
// a one-function wrapper that calls New with its own base URL and Rules.
package openaicompat

import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"net/http"
	"path"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// Rules encodes provider-specific constraints on top of the OpenAI wire
// protocol. The zero value means "no restrictions" and behaves like vanilla
// OpenAI. Individual fields are documented inline.
type Rules struct {
	// MaxImagesPerMessage rejects requests in which any single message
	// carries more images than this cap. 0 means "no cap".
	MaxImagesPerMessage int

	// MaxAudioPerMessage rejects requests in which any single message
	// carries more audio attachments than this cap. 0 means "no cap".
	MaxAudioPerMessage int

	// SupportsVision, when non-nil, is consulted for every request that
	// includes any image attachments. If it returns false for the request's
	// model, the call fails with a FeatureUnsupportedError before hitting
	// the network.
	SupportsVision func(model string) bool

	// SupportsTools, when non-nil, is consulted for every request that
	// includes any tool definitions. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError before hitting the
	// network.
	SupportsTools func(model string) bool

	// SupportsAudio, when non-nil, is consulted for every request that
	// includes any audio attachments. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError.
	SupportsAudio func(model string) bool

	// RestrictTemperature, when non-nil and returning true for the request's
	// model, causes the Temperature field to be silently dropped from the
	// outgoing request. Used by OpenAI o-series and gpt-5* which reject a
	// user-provided temperature.
	RestrictTemperature func(model string) bool

	// CustomizeRequest is a last-mile hook invoked after buildRequest but
	// before the call is sent. It receives the fully built OpenAI SDK
	// parameters and may mutate them freely (add headers, flip flags, tweak
	// response_format, etc.).
	CustomizeRequest func(params *openai.ChatCompletionNewParams)
}

// FeatureUnsupportedError is returned when a Rules predicate rejects a request
// because the target model does not support a feature the caller included.
type FeatureUnsupportedError struct {
	Feature string
	Model   string
}

func (e *FeatureUnsupportedError) Error() string {
	return fmt.Sprintf("openaicompat: model %q does not support %s", e.Model, e.Feature)
}

// Provider implements provider.Provider for any OpenAI-compatible endpoint.
type Provider struct {
	apiKey  string
	baseURL string
	rules   Rules
}

// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its
// default; in practice concrete provider packages always pass a default.
func New(apiKey, baseURL string, rules Rules) *Provider {
	return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules}
}

// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	if err := p.checkRules(req); err != nil {
		return provider.Response{}, err
	}

	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	resp, err := cl.Chat.Completions.New(ctx, oaiReq)
	if err != nil {
		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
	}

	return p.convertResponse(resp), nil
}

// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	if err := p.checkRules(req); err != nil {
		return err
	}

	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)

	var fullText strings.Builder
	var toolCalls []provider.ToolCall
	toolCallArgs := map[int]*strings.Builder{}
	var usage *provider.Usage

	for stream.Next() {
		chunk := stream.Current()

		// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
		if chunk.Usage.TotalTokens > 0 {
			usage = &provider.Usage{
				InputTokens:  int(chunk.Usage.PromptTokens),
				OutputTokens: int(chunk.Usage.CompletionTokens),
				TotalTokens:  int(chunk.Usage.TotalTokens),
				Details:      extractUsageDetails(chunk.Usage),
			}
		}

		for _, choice := range chunk.Choices {
			// Text delta
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				events <- provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}
			}

			// Tool call deltas
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)

				if tc.ID != "" {
					// New tool call starting
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}

					events <- provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}
				}

				if tc.Function.Arguments != "" {
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}
				}
			}
		}
	}

	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}

	// Finalize tool calls
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}
	}

	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
			Usage:     usage,
		},
	}

	return nil
}

func (p *Provider) requestOptions() []option.RequestOption {
	opts := []option.RequestOption{option.WithAPIKey(p.apiKey)}
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}
	return opts
}

// checkRules applies all Rules predicates against a request and returns an
// error if any constraint is violated. Runs before any network call.
func (p *Provider) checkRules(req provider.Request) error {
	var hasImages, hasAudio bool
	for _, msg := range req.Messages {
		if len(msg.Images) > 0 {
			hasImages = true
		}
		if len(msg.Audio) > 0 {
			hasAudio = true
		}
		if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage {
			return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q",
				len(msg.Images), p.rules.MaxImagesPerMessage, req.Model)
		}
		if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage {
			return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q",
				len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model)
		}
	}

	if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) {
		return &FeatureUnsupportedError{Feature: "vision", Model: req.Model}
	}
	if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) {
		return &FeatureUnsupportedError{Feature: "audio", Model: req.Model}
	}
	if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) {
		return &FeatureUnsupportedError{Feature: "tools", Model: req.Model}
	}
	return nil
}

func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
	oaiReq := openai.ChatCompletionNewParams{
		Model: req.Model,
	}

	for _, msg := range req.Messages {
		oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
	}

	for _, tool := range req.Tools {
		oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
			Type: "function",
			Function: shared.FunctionDefinitionParam{
				Name:        tool.Name,
				Description: openai.String(tool.Description),
				Parameters:  openai.FunctionParameters(tool.Schema),
			},
		})
	}

	if req.Temperature != nil {
		if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) {
			oaiReq.Temperature = openai.Float(*req.Temperature)
		}
	}

	if req.MaxTokens != nil {
		oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
	}

	if req.TopP != nil {
		oaiReq.TopP = openai.Float(*req.TopP)
	}

	if len(req.Stop) > 0 {
		oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
	}

	return oaiReq
}

func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
	var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
	var textContent param.Opt[string]

	for _, img := range msg.Images {
		var url string
		if img.Base64 != "" {
			url = "data:" + img.ContentType + ";base64," + img.Base64
		} else if img.URL != "" {
			url = img.URL
		}
		if url != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: url,
						},
					},
				},
			)
		}
	}

	for _, aud := range msg.Audio {
		var b64Data string
		var format string

		if aud.Base64 != "" {
			b64Data = aud.Base64
			format = audioFormat(aud.ContentType)
		} else if aud.URL != "" {
			resp, err := http.Get(aud.URL)
			if err != nil {
				continue
			}
			data, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				continue
			}
			b64Data = base64.StdEncoding.EncodeToString(data)
			ct := resp.Header.Get("Content-Type")
			if ct == "" {
				ct = aud.ContentType
			}
			if ct == "" {
				ct = audioFormatFromURL(aud.URL)
			}
			format = audioFormat(ct)
		}

		if b64Data != "" && format != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
						InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
							Data:   b64Data,
							Format: format,
						},
					},
				},
			)
		}
	}

	if msg.Content != "" {
		if len(arrayOfContentParts) > 0 {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: msg.Content,
					},
				},
			)
		} else {
			textContent = openai.String(msg.Content)
		}
	}

	// Determine if this model uses developer messages instead of system
	useDeveloper := false
	parts := strings.Split(model, "-")
	if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
		useDeveloper = true
	}

	switch msg.Role {
	case "system":
		if useDeveloper {
			return openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfSystem: &openai.ChatCompletionSystemMessageParam{
				Content: openai.ChatCompletionSystemMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}

	case "user":
		return openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{
				Content: openai.ChatCompletionUserMessageParamContentUnion{
					OfString:              textContent,
					OfArrayOfContentParts: arrayOfContentParts,
				},
			},
		}

	case "assistant":
		as := &openai.ChatCompletionAssistantMessageParam{}
		if msg.Content != "" {
			as.Content.OfString = openai.String(msg.Content)
		}
		for _, tc := range msg.ToolCalls {
			as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
				ID: tc.ID,
				Function: openai.ChatCompletionMessageToolCallFunctionParam{
					Name:      tc.Name,
					Arguments: tc.Arguments,
				},
			})
		}
		return openai.ChatCompletionMessageParamUnion{OfAssistant: as}

	case "tool":
		return openai.ChatCompletionMessageParamUnion{
			OfTool: &openai.ChatCompletionToolMessageParam{
				ToolCallID: msg.ToolCallID,
				Content: openai.ChatCompletionToolMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				},
			},
		}
	}

	// Fallback to user message
	return openai.ChatCompletionMessageParamUnion{
		OfUser: &openai.ChatCompletionUserMessageParam{
			Content: openai.ChatCompletionUserMessageParamContentUnion{
				OfString: textContent,
			},
		},
	}
}

func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
	var res provider.Response

	if resp == nil || len(resp.Choices) == 0 {
		return res
	}

	choice := resp.Choices[0]
	res.Text = choice.Message.Content

	for _, tc := range choice.Message.ToolCalls {
		res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
			ID:        tc.ID,
			Name:      tc.Function.Name,
			Arguments: strings.TrimSpace(tc.Function.Arguments),
		})
	}

	if resp.Usage.TotalTokens > 0 {
		res.Usage = &provider.Usage{
			InputTokens:  int(resp.Usage.PromptTokens),
			OutputTokens: int(resp.Usage.CompletionTokens),
			TotalTokens:  int(resp.Usage.TotalTokens),
		}
		res.Usage.Details = extractUsageDetails(resp.Usage)
	}

	return res
}

// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3").
func audioFormat(contentType string) string {
	ct := strings.ToLower(contentType)
	switch {
	case strings.Contains(ct, "wav"):
		return "wav"
	case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"):
		return "mp3"
	default:
		return "wav"
	}
}

// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
	details := map[string]int{}
	if usage.CompletionTokensDetails.ReasoningTokens > 0 {
		details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
	}
	if usage.CompletionTokensDetails.AudioTokens > 0 {
		details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
	}
	if usage.PromptTokensDetails.CachedTokens > 0 {
		details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
	}
	if usage.PromptTokensDetails.AudioTokens > 0 {
		details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
	}
	if len(details) == 0 {
		return nil
	}
	return details
}

// audioFormatFromURL guesses the audio format from a URL's file extension.
func audioFormatFromURL(u string) string {
	ext := strings.ToLower(path.Ext(u))
	switch ext {
	case ".mp3":
		return "audio/mp3"
	case ".wav":
		return "audio/wav"
	default:
		return "audio/wav"
	}
}
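The commit message says each provider supplies its own Rules value over model names: Moonshot and xAI accept images only on *-vision* models, while the v2/openai wrapper restricts temperature for o-series and gpt-5 models. A standalone sketch of predicates in that style (the vision predicate and the example model names are illustrative, not taken from the providers' actual rule code):

```go
package main

import (
	"fmt"
	"strings"
)

// supportsVision is a hypothetical predicate in the style the commit
// describes for Moonshot/xAI: only model names containing "vision"
// accept image attachments.
func supportsVision(model string) bool {
	return strings.Contains(model, "vision")
}

// restrictTemperature matches the check the v2/openai wrapper installs:
// o-series and gpt-5* models reject a user-supplied temperature.
func restrictTemperature(model string) bool {
	return strings.HasPrefix(model, "o") || strings.HasPrefix(model, "gpt-5")
}

func main() {
	fmt.Println(supportsVision("moonshot-v1-8k-vision-preview")) // true
	fmt.Println(supportsVision("moonshot-v1-8k"))                // false
	fmt.Println(restrictTemperature("o3-mini"))                  // true
	fmt.Println(restrictTemperature("gpt-4o"))                   // false
}
```

Because Rules takes plain `func(model string) bool` values, these stay trivially unit-testable without any network or SDK machinery, which is what the test file below exploits.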
@@ -0,0 +1,313 @@
package openaicompat_test

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/openai/openai-go"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// newTestServer returns an httptest server that captures the raw request body
// on POST /chat/completions and returns a canned OpenAI response so Complete()
// succeeds. Use the returned body pointer to assert on what the provider sent.
func newTestServer(t *testing.T) (*httptest.Server, *[]byte) {
	t.Helper()
	var body []byte
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			http.NotFound(w, r)
			return
		}
		b, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read body: %v", err)
		}
		body = b
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{
			"id": "cmpl-1",
			"object": "chat.completion",
			"choices": [{
				"index": 0,
				"message": {"role":"assistant","content":"ok"},
				"finish_reason": "stop"
			}],
			"usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
		}`)
	}))
	return srv, &body
}

func textReq(model, content string) provider.Request {
	return provider.Request{
		Model:    model,
		Messages: []provider.Message{{Role: "user", Content: content}},
	}
}

func TestComplete_ZeroRulesPassesThrough(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	temp := 0.7
	req := textReq("gpt-4o", "hi")
	req.Temperature = &temp

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	resp, err := p.Complete(context.Background(), req)
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Text != "ok" {
		t.Errorf("Text = %q, want %q", resp.Text, "ok")
	}

	// Temperature should be present since RestrictTemperature is nil.
	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal request body: %v", err)
	}
	if _, ok := parsed["temperature"]; !ok {
		t.Errorf("expected temperature in request body, got: %s", *body)
	}
}

func TestComplete_RestrictTemperatureDropsField(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	temp := 0.7
	req := textReq("o1", "hi")
	req.Temperature = &temp

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") },
	})
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if _, ok := parsed["temperature"]; ok {
		t.Errorf("temperature should be dropped for o1, got: %s", *body)
	}
}

func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "deepseek-chat",
		Messages: []provider.Message{{
			Role:    "user",
			Content: "describe",
			Images:  []provider.Image{{URL: "https://example.com/a.png"}},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsVision: func(string) bool { return false },
	})
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "vision" || fue.Model != "deepseek-chat" {
		t.Errorf("unexpected err: %+v", fue)
	}
}

func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model:    "deepseek-reasoner",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
		Tools: []provider.ToolDef{
			{Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}},
		},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") },
	})
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "tools" {
		t.Errorf("feature = %q, want tools", fue.Feature)
	}
}

func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "groq-llama",
		Messages: []provider.Message{{
			Role:  "user",
			Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsAudio: func(string) bool { return false },
	})
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "audio" {
		t.Errorf("feature = %q, want audio", fue.Feature)
	}
}

func TestComplete_MaxImagesPerMessage(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "anything",
		Messages: []provider.Message{{
			Role: "user",
			Images: []provider.Image{
				{URL: "a"}, {URL: "b"}, {URL: "c"},
			},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2})
	_, err := p.Complete(context.Background(), req)
	if err == nil || !strings.Contains(err.Error(), "max allowed is 2") {
		t.Fatalf("want max-images error, got %v", err)
	}

	// Exactly at the limit succeeds.
	req.Messages[0].Images = req.Messages[0].Images[:2]
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Errorf("at-limit request should succeed, got %v", err)
	}
}

func TestComplete_CustomizeRequestInvoked(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	called := false
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		CustomizeRequest: func(params *openai.ChatCompletionNewParams) {
			called = true
			// Confirm we receive a non-empty built request.
			if params.Model != "gpt-4o" {
				t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model)
			}
			// Mutation here should end up on the wire.
			params.User = openai.String("test-user")
		},
	})
	if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if !called {
		t.Fatal("CustomizeRequest hook was not invoked")
	}
	if !strings.Contains(string(*body), `"user":"test-user"`) {
		t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body)
	}
}

func TestStream_EmitsDoneAndText(t *testing.T) {
	// SSE stream with two content chunks, a usage-bearing final chunk, then [DONE].
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		for _, line := range []string{
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
			`data: [DONE]`,
		} {
			_, _ = io.WriteString(w, line+"\n\n")
			if flusher != nil {
				flusher.Flush()
			}
		}
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	events := make(chan provider.StreamEvent, 16)
	go func() {
		_ = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events)
		close(events)
	}()

	var text strings.Builder
	var sawDone bool
	var doneUsage *provider.Usage
	for ev := range events {
		switch ev.Type {
		case provider.StreamEventText:
			text.WriteString(ev.Text)
		case provider.StreamEventDone:
			sawDone = true
			if ev.Response != nil {
				doneUsage = ev.Response.Usage
			}
		}
	}
	if text.String() != "hello" {
		t.Errorf("got text %q, want %q", text.String(), "hello")
	}
	if !sawDone {
		t.Fatal("no Done event emitted")
	}
	if doneUsage == nil || doneUsage.TotalTokens != 3 {
		t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage)
	}
}

func TestStream_RulesCheckedBeforeNetwork(t *testing.T) {
	// The server should never be hit when rules reject up front.
	hit := false
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hit = true
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsVision: func(string) bool { return false },
	})
	req := provider.Request{
		Model: "no-vision-model",
		Messages: []provider.Message{{
			Role:   "user",
			Images: []provider.Image{{URL: "a"}},
		}},
	}
	events := make(chan provider.StreamEvent, 4)
	err := p.Stream(context.Background(), req, events)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if hit {
		t.Error("server was contacted despite Rules violation")
	}
}
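The tests above all exercise the same pattern: capability predicates in a Rules value are evaluated before any network call, so an unsupported request fails fast with a typed error. The gate can be sketched in a self-contained form (the `Rules`, `FeatureUnsupportedError`, and `checkVision` names below are local stand-ins for illustration, not the package's actual internals):

```go
package main

import "fmt"

// Rules declares per-model capability predicates; a nil predicate means
// "no restriction", mirroring how a zero openaicompat.Rules passes everything.
type Rules struct {
	SupportsVision func(model string) bool
}

// FeatureUnsupportedError reports which feature a model rejected.
type FeatureUnsupportedError struct {
	Feature, Model string
}

func (e *FeatureUnsupportedError) Error() string {
	return fmt.Sprintf("model %q does not support %s", e.Model, e.Feature)
}

// checkVision is the gate applied before contacting the API: it only fires
// when the request actually carries images AND a predicate is set AND that
// predicate rejects the model.
func checkVision(r Rules, model string, hasImages bool) error {
	if hasImages && r.SupportsVision != nil && !r.SupportsVision(model) {
		return &FeatureUnsupportedError{Feature: "vision", Model: model}
	}
	return nil
}

func main() {
	r := Rules{SupportsVision: func(m string) bool { return false }}
	fmt.Println(checkVision(r, "deepseek-chat", true))  // rejected before any I/O
	fmt.Println(checkVision(Rules{}, "deepseek-chat", true)) // zero Rules pass through
}
```

Because the check runs on plain data, providers can layer arbitrary per-model policy without touching the shared wire code, which is what lets each wrapper package stay a single file.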
@@ -0,0 +1,158 @@
package llm

import (
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/ollama"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
)

// ProviderInfo describes a registered provider for discovery purposes (CLI
// pickers, wiring layers, admin tools). It is the single source of truth for
// "what providers exist and how do I instantiate one."
type ProviderInfo struct {
	// Name is the short lowercase identifier used in provider/model strings
	// (e.g., "openai", "deepseek", "moonshot").
	Name string

	// DisplayName is a human-readable label for UIs.
	DisplayName string

	// EnvKey is the conventional environment variable that holds the API key
	// for this provider. An empty string means no key is needed (e.g., Ollama).
	EnvKey string

	// DefaultURL is the default base URL used when no override is supplied.
	DefaultURL string

	// Models is a list of well-known model names, populated for CLI pickers
	// and similar. It is not exhaustive and not validated against the API.
	Models []string

	// New returns a ready-to-use Client for this provider, given an API key
	// (ignored for key-less providers like Ollama) and optional ClientOptions.
	New func(apiKey string, opts ...ClientOption) *Client
}

// providerRegistry is the in-process list of known providers. Order is
// intentional: the three original providers first, then OpenAI-compatible
// additions in the order they were added.
var providerRegistry = []ProviderInfo{
	{
		Name:        "openai",
		DisplayName: "OpenAI",
		EnvKey:      "OPENAI_API_KEY",
		DefaultURL:  openai.DefaultBaseURL,
		Models: []string{
			"gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano",
			"gpt-4o", "gpt-4o-mini",
			"gpt-4-turbo", "gpt-3.5-turbo",
			"o1", "o1-mini", "o1-preview", "o3-mini",
		},
		New: OpenAI,
	},
	{
		Name:        "anthropic",
		DisplayName: "Anthropic",
		EnvKey:      "ANTHROPIC_API_KEY",
		DefaultURL:  "https://api.anthropic.com",
		Models: []string{
			"claude-opus-4-7",
			"claude-sonnet-4-6",
			"claude-haiku-4-5-20251001",
			"claude-opus-4-20250514",
			"claude-sonnet-4-20250514",
			"claude-3-7-sonnet-20250219",
			"claude-3-5-sonnet-20241022",
			"claude-3-5-haiku-20241022",
		},
		New: Anthropic,
	},
	{
		Name:        "google",
		DisplayName: "Google",
		EnvKey:      "GOOGLE_API_KEY",
		DefaultURL:  "https://generativelanguage.googleapis.com",
		Models: []string{
			"gemini-2.0-flash", "gemini-2.0-flash-lite",
			"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
		},
		New: Google,
	},
	{
		Name:        "deepseek",
		DisplayName: "DeepSeek",
		EnvKey:      "DEEPSEEK_API_KEY",
		DefaultURL:  deepseek.DefaultBaseURL,
		Models:      []string{"deepseek-chat", "deepseek-reasoner"},
		New:         DeepSeek,
	},
	{
		Name:        "moonshot",
		DisplayName: "Moonshot (Kimi)",
		EnvKey:      "MOONSHOT_API_KEY",
		DefaultURL:  moonshot.DefaultBaseURL,
		Models: []string{
			"kimi-k2-0711-preview",
			"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
			"moonshot-v1-8k-vision-preview",
		},
		New: Moonshot,
	},
	{
		Name:        "xai",
		DisplayName: "xAI (Grok)",
		EnvKey:      "XAI_API_KEY",
		DefaultURL:  xai.DefaultBaseURL,
		Models: []string{
			"grok-2", "grok-2-mini", "grok-2-vision", "grok-beta",
		},
		New: XAI,
	},
	{
		Name:        "groq",
		DisplayName: "Groq",
		EnvKey:      "GROQ_API_KEY",
		DefaultURL:  groq.DefaultBaseURL,
		Models: []string{
			"llama-3.3-70b-versatile",
			"llama-3.1-8b-instant",
			"mixtral-8x7b-32768",
			"gemma2-9b-it",
			"llama-3.2-90b-vision-preview",
		},
		New: Groq,
	},
	{
		Name:        "ollama",
		DisplayName: "Ollama (local)",
		EnvKey:      "", // no key needed
		DefaultURL:  ollama.DefaultBaseURL,
		Models: []string{
			"llama3.2", "llama3.1", "qwen2.5", "mistral", "gemma2", "phi4",
		},
		New: func(_ string, opts ...ClientOption) *Client { return Ollama(opts...) },
	},
}

// Providers returns a copy of the registered provider list so callers cannot
// mutate library state.
func Providers() []ProviderInfo {
	out := make([]ProviderInfo, len(providerRegistry))
	copy(out, providerRegistry)
	return out
}

// ProviderByName returns the registered ProviderInfo with the given name, or
// nil if no such provider is registered. Name matching is exact.
func ProviderByName(name string) *ProviderInfo {
	for i := range providerRegistry {
		if providerRegistry[i].Name == name {
			p := providerRegistry[i]
			return &p
		}
	}
	return nil
}
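Providers() and ProviderByName both hand out copies so callers can never reach the package-level slice. That copy-on-return design can be demonstrated in isolation (the ProviderInfo, registry, and Providers names below are local miniatures of the registry above, defined in a standalone main package for illustration):

```go
package main

import "fmt"

// ProviderInfo is a cut-down version of the registry entry type.
type ProviderInfo struct {
	Name   string
	EnvKey string
}

// registry is the unexported source of truth, like providerRegistry above.
var registry = []ProviderInfo{
	{Name: "openai", EnvKey: "OPENAI_API_KEY"},
	{Name: "ollama", EnvKey: ""}, // key-less provider
}

// Providers returns a copy so callers cannot mutate library state.
func Providers() []ProviderInfo {
	out := make([]ProviderInfo, len(registry))
	copy(out, registry)
	return out
}

func main() {
	ps := Providers()
	ps[0].Name = "mutated"        // caller scribbles on its copy…
	fmt.Println(registry[0].Name) // …but the registry is unaffected
}
```

This program prints `openai`: because ProviderInfo contains only value fields (strings and, in the real type, a function value), a shallow `copy` is enough to isolate callers. A TUI picker can therefore iterate Providers() freely, and adding a provider stays a single-file change to the registry slice.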
@@ -0,0 +1,29 @@
// Package xai implements the go-llm v2 provider interface for xAI (Grok,
// https://x.ai/api). xAI speaks OpenAI Chat Completions, so this package is a
// thin wrapper over openaicompat with its own defaults and per-model Rules.
package xai

import (
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)

// DefaultBaseURL is the public xAI API endpoint.
const DefaultBaseURL = "https://api.x.ai/v1"

// Provider is a type alias for openaicompat.Provider.
type Provider = openaicompat.Provider

// New creates a new xAI provider. An empty baseURL uses DefaultBaseURL.
func New(apiKey, baseURL string) *Provider {
	if baseURL == "" {
		baseURL = DefaultBaseURL
	}
	return openaicompat.New(apiKey, baseURL, openaicompat.Rules{
		// Grok models whose name contains "vision" accept images; others don't.
		SupportsVision: func(m string) bool {
			return strings.Contains(m, "vision")
		},
	})
}
@@ -0,0 +1,33 @@
package xai_test

import (
	"context"
	"errors"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/xai"
)

func TestNew_Basic(t *testing.T) {
	if p := xai.New("key", ""); p == nil {
		t.Fatal("New returned nil")
	}
}

func TestRules_Grok2RejectsImages(t *testing.T) {
	p := xai.New("key", "")
	req := provider.Request{
		Model: "grok-2",
		Messages: []provider.Message{{
			Role:   "user",
			Images: []provider.Image{{URL: "a"}},
		}},
	}
	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) || fue.Feature != "vision" {
		t.Fatalf("want FeatureUnsupportedError(vision), got %v", err)
	}
}