Compare commits
17 Commits
1927f4d187
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e1705c385 | |||
| fc2218b5fe | |||
| 23c9068022 | |||
| 87ec56a2be | |||
| be572a76f4 | |||
| 6a7eeef619 | |||
| cbe340ced0 | |||
| 9e288954f2 | |||
| 9d6d2c61c3 | |||
| a4cb4baab5 | |||
| 85a848d96e | |||
| 8801ce5945 | |||
| 9c1b4f7e9f | |||
| 2cf75ae07d | |||
| 97d54c10ae | |||
| bf7c86ab2a | |||
| be99af3597 |
76
.gitea/workflows/ci.yaml
Normal file
76
.gitea/workflows/ci.yaml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: ["*"]
|
||||||
|
pull_request:
|
||||||
|
branches: ["*"]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
root-module:
|
||||||
|
name: Root Module
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.24"
|
||||||
|
|
||||||
|
- name: Download dependencies
|
||||||
|
run: go mod download
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: go build ./...
|
||||||
|
|
||||||
|
- name: Vet
|
||||||
|
run: go vet ./...
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: go test -race -count=1 ./...
|
||||||
|
|
||||||
|
v2-module:
|
||||||
|
name: V2 Module
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
working-directory: v2
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.24"
|
||||||
|
|
||||||
|
- name: Download dependencies
|
||||||
|
run: go mod download
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: go build ./...
|
||||||
|
|
||||||
|
- name: Vet
|
||||||
|
run: go vet ./...
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: go test -race -count=1 ./...
|
||||||
|
|
||||||
|
lint:
|
||||||
|
name: Lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.24"
|
||||||
|
|
||||||
|
- name: Check root module tidiness
|
||||||
|
run: |
|
||||||
|
go mod tidy
|
||||||
|
git diff --exit-code go.mod go.sum
|
||||||
|
|
||||||
|
- name: Check v2 module tidiness
|
||||||
|
run: |
|
||||||
|
cd v2
|
||||||
|
go mod tidy
|
||||||
|
git diff --exit-code go.mod go.sum
|
||||||
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
.claude
|
||||||
|
.idea
|
||||||
|
*.exe
|
||||||
|
.env
|
||||||
88
CLAUDE.md
Normal file
88
CLAUDE.md
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# CLAUDE.md for go-llm
|
||||||
|
|
||||||
|
## Build and Test Commands
|
||||||
|
- Build project: `go build ./...`
|
||||||
|
- Run all tests: `go test ./...`
|
||||||
|
- Run specific test: `go test -v -run <TestName> ./...`
|
||||||
|
- Tidy dependencies: `go mod tidy`
|
||||||
|
|
||||||
|
## Code Style Guidelines
|
||||||
|
- **Indentation**: Use standard Go tabs for indentation.
|
||||||
|
- **Naming**:
|
||||||
|
- Use `camelCase` for internal/private variables and functions.
|
||||||
|
- Use `PascalCase` for exported types, functions, and struct fields.
|
||||||
|
- Interface names should be concise (e.g., `LLM`, `ChatCompletion`).
|
||||||
|
- **Error Handling**:
|
||||||
|
- Always check and handle errors immediately.
|
||||||
|
- Wrap errors with context using `fmt.Errorf("%w: ...", err)`.
|
||||||
|
- Use the project's internal `Error` struct in `error.go` when differentiating between error types is needed.
|
||||||
|
- **Project Structure**:
|
||||||
|
- `llm.go`: Contains core interfaces (`LLM`, `ChatCompletion`) and shared types (`Message`, `Role`, `Image`).
|
||||||
|
- Provider implementations are in `openai.go`, `anthropic.go`, and `google.go`.
|
||||||
|
- Schema definitions for tool calling are in the `schema/` directory.
|
||||||
|
- `mcp.go`: MCP (Model Context Protocol) client integration for connecting to MCP servers.
|
||||||
|
- **Imports**: Organize imports into groups: standard library, then third-party libraries.
|
||||||
|
- **Documentation**: Use standard Go doc comments for exported symbols.
|
||||||
|
- **README.md**: The README.md file should always be kept up to date with any significant changes to the project.
|
||||||
|
|
||||||
|
## CLI Tool
|
||||||
|
- Build CLI: `go build ./cmd/llm`
|
||||||
|
- Run CLI: `./llm` (or `llm.exe` on Windows)
|
||||||
|
- Run without building: `go run ./cmd/llm`
|
||||||
|
|
||||||
|
### CLI Features
|
||||||
|
- Interactive TUI for testing all go-llm features
|
||||||
|
- Support for OpenAI, Anthropic, and Google providers
|
||||||
|
- Image input (file path, URL, or base64)
|
||||||
|
- Tool/function calling with demo tools
|
||||||
|
- Temperature control and settings
|
||||||
|
|
||||||
|
### Key Bindings
|
||||||
|
- `Enter` - Send message
|
||||||
|
- `Ctrl+I` - Add image
|
||||||
|
- `Ctrl+T` - Toggle tools panel
|
||||||
|
- `Ctrl+P` - Change provider
|
||||||
|
- `Ctrl+M` - Change model
|
||||||
|
- `Ctrl+S` - Settings
|
||||||
|
- `Ctrl+N` - New conversation
|
||||||
|
- `Esc` - Exit/Cancel
|
||||||
|
|
||||||
|
## MCP (Model Context Protocol) Support
|
||||||
|
|
||||||
|
The library supports connecting to MCP servers to use their tools. MCP servers can be connected via:
|
||||||
|
- **stdio**: Run a command as a subprocess
|
||||||
|
- **sse**: Connect to an SSE endpoint
|
||||||
|
- **http**: Connect to a streamable HTTP endpoint
|
||||||
|
|
||||||
|
### Usage Example
|
||||||
|
```go
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create and connect to an MCP server
|
||||||
|
server := &llm.MCPServer{
|
||||||
|
Name: "my-server",
|
||||||
|
Command: "my-mcp-server",
|
||||||
|
Args: []string{"--some-flag"},
|
||||||
|
}
|
||||||
|
if err := server.Connect(ctx); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Add the server to a toolbox
|
||||||
|
toolbox := llm.NewToolBox().WithMCPServer(server)
|
||||||
|
|
||||||
|
// Use the toolbox in requests - MCP tools are automatically available
|
||||||
|
req := llm.Request{
|
||||||
|
Messages: []llm.Message{{Role: llm.RoleUser, Text: "Use the MCP tool"}},
|
||||||
|
Toolbox: toolbox,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### MCPServer Options
|
||||||
|
- `Name`: Friendly name for logging
|
||||||
|
- `Command`: Command to run (for stdio transport)
|
||||||
|
- `Args`: Command arguments
|
||||||
|
- `Env`: Additional environment variables
|
||||||
|
- `URL`: Endpoint URL (for sse/http transport)
|
||||||
|
- `Transport`: "stdio" (default), "sse", or "http"
|
||||||
20
anthropic.go
20
anthropic.go
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -10,19 +10,19 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/go-llm/utils"
|
"gitea.stevedudenhoeffer.com/steve/go-llm/internal/imageutil"
|
||||||
|
|
||||||
anth "github.com/liushuangls/go-anthropic/v2"
|
anth "github.com/liushuangls/go-anthropic/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type anthropic struct {
|
type anthropicImpl struct {
|
||||||
key string
|
key string
|
||||||
model string
|
model string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LLM = anthropic{}
|
var _ LLM = anthropicImpl{}
|
||||||
|
|
||||||
func (a anthropic) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
func (a anthropicImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||||
a.model = modelVersion
|
a.model = modelVersion
|
||||||
|
|
||||||
// TODO: model verification?
|
// TODO: model verification?
|
||||||
@@ -36,7 +36,7 @@ func deferClose(c io.Closer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
func (a anthropicImpl) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
||||||
res := anth.MessagesRequest{
|
res := anth.MessagesRequest{
|
||||||
Model: anth.Model(a.model),
|
Model: anth.Model(a.model),
|
||||||
MaxTokens: 1000,
|
MaxTokens: 1000,
|
||||||
@@ -90,7 +90,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
|||||||
// Check if image size exceeds 5MiB (5242880 bytes)
|
// Check if image size exceeds 5MiB (5242880 bytes)
|
||||||
if len(raw) >= 5242880 {
|
if len(raw) >= 5242880 {
|
||||||
|
|
||||||
compressed, mime, err := utils.CompressImage(img.Base64, 5*1024*1024)
|
compressed, mime, err := imageutil.CompressImage(img.Base64, 5*1024*1024)
|
||||||
|
|
||||||
// just replace the image with the compressed one
|
// just replace the image with the compressed one
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -157,7 +157,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tool := range req.Toolbox.functions {
|
for _, tool := range req.Toolbox.Functions() {
|
||||||
res.Tools = append(res.Tools, anth.ToolDefinition{
|
res.Tools = append(res.Tools, anth.ToolDefinition{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: tool.Description,
|
Description: tool.Description,
|
||||||
@@ -177,7 +177,7 @@ func (a anthropic) requestToAnthropicRequest(req Request) anth.MessagesRequest {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
|
func (a anthropicImpl) responseToLLMResponse(in anth.MessagesResponse) Response {
|
||||||
choice := ResponseChoice{}
|
choice := ResponseChoice{}
|
||||||
for _, msg := range in.Content {
|
for _, msg := range in.Content {
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ func (a anthropic) responseToLLMResponse(in anth.MessagesResponse) Response {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a anthropic) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
func (a anthropicImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
||||||
cl := anth.NewClient(a.key)
|
cl := anth.NewClient(a.key)
|
||||||
|
|
||||||
res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req))
|
res, err := cl.CreateMessages(ctx, a.requestToAnthropicRequest(req))
|
||||||
|
|||||||
11
cmd/llm/.env.example
Normal file
11
cmd/llm/.env.example
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# go-llm CLI Environment Variables
|
||||||
|
# Copy this file to .env and fill in your API keys
|
||||||
|
|
||||||
|
# OpenAI API Key (https://platform.openai.com/api-keys)
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
|
||||||
|
# Anthropic API Key (https://console.anthropic.com/settings/keys)
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
|
||||||
|
# Google AI API Key (https://aistudio.google.com/apikey)
|
||||||
|
GOOGLE_API_KEY=
|
||||||
182
cmd/llm/commands.go
Normal file
182
cmd/llm/commands.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message types for async operations
|
||||||
|
|
||||||
|
// ChatResponseMsg contains the response from a chat completion
|
||||||
|
type ChatResponseMsg struct {
|
||||||
|
Response llm.Response
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolExecutionMsg contains results from tool execution
|
||||||
|
type ToolExecutionMsg struct {
|
||||||
|
Results []llm.ToolCallResponse
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageLoadedMsg contains a loaded image
|
||||||
|
type ImageLoadedMsg struct {
|
||||||
|
Image llm.Image
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendChatRequest sends a chat completion request
|
||||||
|
func sendChatRequest(chat llm.ChatCompletion, req llm.Request) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
resp, err := chat.ChatComplete(context.Background(), req)
|
||||||
|
return ChatResponseMsg{Response: resp, Err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeTools executes tool calls and returns results
|
||||||
|
func executeTools(toolbox llm.ToolBox, req llm.Request, resp llm.ResponseChoice) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
ctx := llm.NewContext(context.Background(), req, &resp, nil)
|
||||||
|
var results []llm.ToolCallResponse
|
||||||
|
|
||||||
|
for _, call := range resp.Calls {
|
||||||
|
result, err := toolbox.Execute(ctx, call)
|
||||||
|
results = append(results, llm.ToolCallResponse{
|
||||||
|
ID: call.ID,
|
||||||
|
Result: result,
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return ToolExecutionMsg{Results: results, Err: nil}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadImageFromPath loads an image from a file path
|
||||||
|
func loadImageFromPath(path string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
// Clean up the path
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
path = strings.Trim(path, "\"'")
|
||||||
|
|
||||||
|
// Read the file
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return ImageLoadedMsg{Err: fmt.Errorf("failed to read image file: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect content type
|
||||||
|
contentType := http.DetectContentType(data)
|
||||||
|
if !strings.HasPrefix(contentType, "image/") {
|
||||||
|
return ImageLoadedMsg{Err: fmt.Errorf("file is not an image: %s", contentType)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64 encode
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
|
||||||
|
return ImageLoadedMsg{
|
||||||
|
Image: llm.Image{
|
||||||
|
Base64: encoded,
|
||||||
|
ContentType: contentType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadImageFromURL loads an image from a URL
|
||||||
|
func loadImageFromURL(url string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
url = strings.TrimSpace(url)
|
||||||
|
|
||||||
|
// For URL images, we can just use the URL directly
|
||||||
|
return ImageLoadedMsg{
|
||||||
|
Image: llm.Image{
|
||||||
|
Url: url,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadImageFromBase64 loads an image from base64 data
|
||||||
|
func loadImageFromBase64(data string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
data = strings.TrimSpace(data)
|
||||||
|
|
||||||
|
// Check if it's a data URL
|
||||||
|
if strings.HasPrefix(data, "data:") {
|
||||||
|
// Parse data URL: data:image/png;base64,....
|
||||||
|
parts := strings.SplitN(data, ",", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return ImageLoadedMsg{Err: fmt.Errorf("invalid data URL format")}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract content type from first part
|
||||||
|
mediaType := strings.TrimPrefix(parts[0], "data:")
|
||||||
|
mediaType = strings.TrimSuffix(mediaType, ";base64")
|
||||||
|
|
||||||
|
return ImageLoadedMsg{
|
||||||
|
Image: llm.Image{
|
||||||
|
Base64: parts[1],
|
||||||
|
ContentType: mediaType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assume it's raw base64, try to detect content type
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return ImageLoadedMsg{Err: fmt.Errorf("invalid base64 data: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := http.DetectContentType(decoded)
|
||||||
|
if !strings.HasPrefix(contentType, "image/") {
|
||||||
|
return ImageLoadedMsg{Err: fmt.Errorf("data is not an image: %s", contentType)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ImageLoadedMsg{
|
||||||
|
Image: llm.Image{
|
||||||
|
Base64: data,
|
||||||
|
ContentType: contentType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRequest builds a chat request from the current state
|
||||||
|
func buildRequest(m *Model, userText string) llm.Request {
|
||||||
|
// Create the user message with any pending images
|
||||||
|
userMsg := llm.Message{
|
||||||
|
Role: llm.RoleUser,
|
||||||
|
Text: userText,
|
||||||
|
Images: m.pendingImages,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := llm.Request{
|
||||||
|
Conversation: m.conversation,
|
||||||
|
Messages: []llm.Message{
|
||||||
|
{Role: llm.RoleSystem, Text: m.systemPrompt},
|
||||||
|
userMsg,
|
||||||
|
},
|
||||||
|
Temperature: m.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add toolbox if enabled
|
||||||
|
if m.toolsEnabled && len(m.toolbox.Functions()) > 0 {
|
||||||
|
req.Toolbox = m.toolbox.WithRequireTool(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFollowUpRequest builds a follow-up request after tool execution
|
||||||
|
func buildFollowUpRequest(m *Model, previousReq llm.Request, resp llm.ResponseChoice, toolResults []llm.ToolCallResponse) llm.Request {
|
||||||
|
return previousReq.NextRequest(resp, toolResults)
|
||||||
|
}
|
||||||
25
cmd/llm/main.go
Normal file
25
cmd/llm/main.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/joho/godotenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load .env file if it exists (ignore error if not found)
|
||||||
|
_ = godotenv.Load()
|
||||||
|
|
||||||
|
p := tea.NewProgram(
|
||||||
|
InitialModel(),
|
||||||
|
tea.WithAltScreen(),
|
||||||
|
tea.WithMouseCellMotion(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := p.Run(); err != nil {
|
||||||
|
fmt.Printf("Error running program: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
295
cmd/llm/model.go
Normal file
295
cmd/llm/model.go
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State represents the current view/screen of the application
|
||||||
|
type State int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateChat State = iota
|
||||||
|
StateProviderSelect
|
||||||
|
StateModelSelect
|
||||||
|
StateImageInput
|
||||||
|
StateToolsPanel
|
||||||
|
StateSettings
|
||||||
|
StateAPIKeyInput
|
||||||
|
)
|
||||||
|
|
||||||
|
// DisplayMessage represents a message for display in the UI
|
||||||
|
type DisplayMessage struct {
|
||||||
|
Role llm.Role
|
||||||
|
Content string
|
||||||
|
Images int // number of images attached
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderInfo contains information about a provider
|
||||||
|
type ProviderInfo struct {
|
||||||
|
Name string
|
||||||
|
EnvVar string
|
||||||
|
Models []string
|
||||||
|
HasAPIKey bool
|
||||||
|
ModelIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model is the main Bubble Tea model
|
||||||
|
type Model struct {
|
||||||
|
// State
|
||||||
|
state State
|
||||||
|
previousState State
|
||||||
|
|
||||||
|
// Provider
|
||||||
|
provider llm.LLM
|
||||||
|
providerName string
|
||||||
|
chat llm.ChatCompletion
|
||||||
|
modelName string
|
||||||
|
apiKeys map[string]string
|
||||||
|
providers []ProviderInfo
|
||||||
|
providerIndex int
|
||||||
|
|
||||||
|
// Conversation
|
||||||
|
conversation []llm.Input
|
||||||
|
messages []DisplayMessage
|
||||||
|
|
||||||
|
// Tools
|
||||||
|
toolbox llm.ToolBox
|
||||||
|
toolsEnabled bool
|
||||||
|
|
||||||
|
// Settings
|
||||||
|
systemPrompt string
|
||||||
|
temperature *float64
|
||||||
|
|
||||||
|
// Pending images
|
||||||
|
pendingImages []llm.Image
|
||||||
|
|
||||||
|
// UI Components
|
||||||
|
input textinput.Model
|
||||||
|
viewport viewport.Model
|
||||||
|
viewportReady bool
|
||||||
|
|
||||||
|
// Selection state (for lists)
|
||||||
|
listIndex int
|
||||||
|
listItems []string
|
||||||
|
|
||||||
|
// Dimensions
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
|
||||||
|
// Loading state
|
||||||
|
loading bool
|
||||||
|
err error
|
||||||
|
|
||||||
|
// For API key input
|
||||||
|
apiKeyInput textinput.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitialModel creates and returns the initial model
|
||||||
|
func InitialModel() Model {
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.Placeholder = "Type your message..."
|
||||||
|
ti.Focus()
|
||||||
|
ti.CharLimit = 4096
|
||||||
|
ti.Width = 60
|
||||||
|
|
||||||
|
aki := textinput.New()
|
||||||
|
aki.Placeholder = "Enter API key..."
|
||||||
|
aki.CharLimit = 256
|
||||||
|
aki.Width = 60
|
||||||
|
aki.EchoMode = textinput.EchoPassword
|
||||||
|
|
||||||
|
// Initialize providers with environment variable checks
|
||||||
|
providers := []ProviderInfo{
|
||||||
|
{
|
||||||
|
Name: "OpenAI",
|
||||||
|
EnvVar: "OPENAI_API_KEY",
|
||||||
|
Models: []string{
|
||||||
|
"gpt-4.1",
|
||||||
|
"gpt-4.1-mini",
|
||||||
|
"gpt-4.1-nano",
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-4o-mini",
|
||||||
|
"gpt-4-turbo",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"o1",
|
||||||
|
"o1-mini",
|
||||||
|
"o1-preview",
|
||||||
|
"o3-mini",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Anthropic",
|
||||||
|
EnvVar: "ANTHROPIC_API_KEY",
|
||||||
|
Models: []string{
|
||||||
|
"claude-sonnet-4-20250514",
|
||||||
|
"claude-opus-4-20250514",
|
||||||
|
"claude-3-7-sonnet-20250219",
|
||||||
|
"claude-3-5-sonnet-20241022",
|
||||||
|
"claude-3-5-haiku-20241022",
|
||||||
|
"claude-3-opus-20240229",
|
||||||
|
"claude-3-sonnet-20240229",
|
||||||
|
"claude-3-haiku-20240307",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Google",
|
||||||
|
EnvVar: "GOOGLE_API_KEY",
|
||||||
|
Models: []string{
|
||||||
|
"gemini-2.0-flash",
|
||||||
|
"gemini-2.0-flash-lite",
|
||||||
|
"gemini-1.5-pro",
|
||||||
|
"gemini-1.5-flash",
|
||||||
|
"gemini-1.5-flash-8b",
|
||||||
|
"gemini-1.0-pro",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for API keys in environment
|
||||||
|
apiKeys := make(map[string]string)
|
||||||
|
for i := range providers {
|
||||||
|
if key := os.Getenv(providers[i].EnvVar); key != "" {
|
||||||
|
apiKeys[providers[i].Name] = key
|
||||||
|
providers[i].HasAPIKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
state: StateProviderSelect,
|
||||||
|
input: ti,
|
||||||
|
apiKeyInput: aki,
|
||||||
|
apiKeys: apiKeys,
|
||||||
|
providers: providers,
|
||||||
|
systemPrompt: "You are a helpful assistant.",
|
||||||
|
toolbox: createDemoToolbox(),
|
||||||
|
toolsEnabled: false,
|
||||||
|
messages: []DisplayMessage{},
|
||||||
|
conversation: []llm.Input{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build list items for provider selection
|
||||||
|
m.listItems = make([]string, len(providers))
|
||||||
|
for i, p := range providers {
|
||||||
|
status := " (no key)"
|
||||||
|
if p.HasAPIKey {
|
||||||
|
status = " (ready)"
|
||||||
|
}
|
||||||
|
m.listItems[i] = p.Name + status
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes the model
|
||||||
|
func (m Model) Init() tea.Cmd {
|
||||||
|
return textinput.Blink
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectProvider sets up the selected provider
|
||||||
|
func (m *Model) selectProvider(index int) error {
|
||||||
|
if index < 0 || index >= len(m.providers) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p := m.providers[index]
|
||||||
|
key, ok := m.apiKeys[p.Name]
|
||||||
|
if !ok || key == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.providerName = p.Name
|
||||||
|
m.providerIndex = index
|
||||||
|
|
||||||
|
switch p.Name {
|
||||||
|
case "OpenAI":
|
||||||
|
m.provider = llm.OpenAI(key)
|
||||||
|
case "Anthropic":
|
||||||
|
m.provider = llm.Anthropic(key)
|
||||||
|
case "Google":
|
||||||
|
m.provider = llm.Google(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select default model
|
||||||
|
if len(p.Models) > 0 {
|
||||||
|
return m.selectModel(p.ModelIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectModel sets the current model
|
||||||
|
func (m *Model) selectModel(index int) error {
|
||||||
|
if m.provider == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p := m.providers[m.providerIndex]
|
||||||
|
if index < 0 || index >= len(p.Models) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := p.Models[index]
|
||||||
|
chat, err := m.provider.ModelVersion(modelName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.chat = chat
|
||||||
|
m.modelName = modelName
|
||||||
|
m.providers[m.providerIndex].ModelIndex = index
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newConversation resets the conversation
|
||||||
|
func (m *Model) newConversation() {
|
||||||
|
m.conversation = []llm.Input{}
|
||||||
|
m.messages = []DisplayMessage{}
|
||||||
|
m.pendingImages = []llm.Image{}
|
||||||
|
m.err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addUserMessage adds a user message to the conversation
|
||||||
|
func (m *Model) addUserMessage(text string, images []llm.Image) {
|
||||||
|
msg := llm.Message{
|
||||||
|
Role: llm.RoleUser,
|
||||||
|
Text: text,
|
||||||
|
Images: images,
|
||||||
|
}
|
||||||
|
m.conversation = append(m.conversation, msg)
|
||||||
|
m.messages = append(m.messages, DisplayMessage{
|
||||||
|
Role: llm.RoleUser,
|
||||||
|
Content: text,
|
||||||
|
Images: len(images),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addAssistantMessage adds an assistant message to the conversation
|
||||||
|
func (m *Model) addAssistantMessage(content string) {
|
||||||
|
m.messages = append(m.messages, DisplayMessage{
|
||||||
|
Role: llm.RoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addToolCallMessage adds a tool call message to display
|
||||||
|
func (m *Model) addToolCallMessage(name string, args string) {
|
||||||
|
m.messages = append(m.messages, DisplayMessage{
|
||||||
|
Role: llm.Role("tool_call"),
|
||||||
|
Content: name + ": " + args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addToolResultMessage adds a tool result message to display
|
||||||
|
func (m *Model) addToolResultMessage(name string, result string) {
|
||||||
|
m.messages = append(m.messages, DisplayMessage{
|
||||||
|
Role: llm.Role("tool_result"),
|
||||||
|
Content: name + " -> " + result,
|
||||||
|
})
|
||||||
|
}
|
||||||
113
cmd/llm/styles.go
Normal file
113
cmd/llm/styles.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Colors
|
||||||
|
primaryColor = lipgloss.Color("205")
|
||||||
|
secondaryColor = lipgloss.Color("39")
|
||||||
|
accentColor = lipgloss.Color("212")
|
||||||
|
mutedColor = lipgloss.Color("241")
|
||||||
|
errorColor = lipgloss.Color("196")
|
||||||
|
successColor = lipgloss.Color("82")
|
||||||
|
|
||||||
|
// App styles
|
||||||
|
appStyle = lipgloss.NewStyle().Padding(1, 2)
|
||||||
|
|
||||||
|
// Header
|
||||||
|
headerStyle = lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(primaryColor).
|
||||||
|
BorderStyle(lipgloss.NormalBorder()).
|
||||||
|
BorderBottom(true).
|
||||||
|
BorderForeground(mutedColor).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
// Provider badge
|
||||||
|
providerBadgeStyle = lipgloss.NewStyle().
|
||||||
|
Background(secondaryColor).
|
||||||
|
Foreground(lipgloss.Color("0")).
|
||||||
|
Padding(0, 1).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
// Messages
|
||||||
|
systemMsgStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(mutedColor).
|
||||||
|
Italic(true).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
userMsgStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(secondaryColor).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
assistantMsgStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("255")).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
roleLabelStyle = lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Width(12)
|
||||||
|
|
||||||
|
// Tool calls
|
||||||
|
toolCallStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(accentColor).
|
||||||
|
Italic(true).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
toolResultStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(successColor).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
// Input area
|
||||||
|
inputStyle = lipgloss.NewStyle().
|
||||||
|
BorderStyle(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(primaryColor).
|
||||||
|
Padding(0, 1)
|
||||||
|
|
||||||
|
inputHelpStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(mutedColor).
|
||||||
|
Italic(true)
|
||||||
|
|
||||||
|
// Error
|
||||||
|
errorStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(errorColor).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
// Loading
|
||||||
|
loadingStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(accentColor).
|
||||||
|
Italic(true)
|
||||||
|
|
||||||
|
// List selection
|
||||||
|
selectedItemStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(primaryColor).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
normalItemStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("255"))
|
||||||
|
|
||||||
|
// Settings panel
|
||||||
|
settingLabelStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(secondaryColor).
|
||||||
|
Width(15)
|
||||||
|
|
||||||
|
settingValueStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("255"))
|
||||||
|
|
||||||
|
// Help text
|
||||||
|
helpStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(mutedColor).
|
||||||
|
Padding(1, 0)
|
||||||
|
|
||||||
|
// Image indicator
|
||||||
|
imageIndicatorStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(accentColor).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
// Viewport
|
||||||
|
viewportStyle = lipgloss.NewStyle().
|
||||||
|
BorderStyle(lipgloss.NormalBorder()).
|
||||||
|
BorderForeground(mutedColor)
|
||||||
|
)
|
||||||
105
cmd/llm/tools.go
Normal file
105
cmd/llm/tools.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TimeParams is the parameter struct for the GetTime function
|
||||||
|
type TimeParams struct{}
|
||||||
|
|
||||||
|
// GetTime returns the current time
|
||||||
|
func GetTime(_ *llm.Context, _ TimeParams) (any, error) {
|
||||||
|
return time.Now().Format("Monday, January 2, 2006 3:04:05 PM MST"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalcParams is the parameter struct for the Calculate function
|
||||||
|
type CalcParams struct {
|
||||||
|
A float64 `json:"a" description:"First number"`
|
||||||
|
B float64 `json:"b" description:"Second number"`
|
||||||
|
Op string `json:"op" description:"Operation: add, subtract, multiply, divide, power, sqrt, mod"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate performs basic math operations
|
||||||
|
func Calculate(_ *llm.Context, params CalcParams) (any, error) {
|
||||||
|
switch strings.ToLower(params.Op) {
|
||||||
|
case "add", "+":
|
||||||
|
return params.A + params.B, nil
|
||||||
|
case "subtract", "sub", "-":
|
||||||
|
return params.A - params.B, nil
|
||||||
|
case "multiply", "mul", "*":
|
||||||
|
return params.A * params.B, nil
|
||||||
|
case "divide", "div", "/":
|
||||||
|
if params.B == 0 {
|
||||||
|
return nil, fmt.Errorf("division by zero")
|
||||||
|
}
|
||||||
|
return params.A / params.B, nil
|
||||||
|
case "power", "pow", "^":
|
||||||
|
return math.Pow(params.A, params.B), nil
|
||||||
|
case "sqrt":
|
||||||
|
if params.A < 0 {
|
||||||
|
return nil, fmt.Errorf("cannot take square root of negative number")
|
||||||
|
}
|
||||||
|
return math.Sqrt(params.A), nil
|
||||||
|
case "mod", "%":
|
||||||
|
return math.Mod(params.A, params.B), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown operation: %s", params.Op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WeatherParams is the parameter struct for the GetWeather function
|
||||||
|
type WeatherParams struct {
|
||||||
|
Location string `json:"location" description:"City name or location"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWeather returns mock weather data (for demo purposes)
|
||||||
|
func GetWeather(_ *llm.Context, params WeatherParams) (any, error) {
|
||||||
|
// This is a demo function - returns mock data
|
||||||
|
weathers := []string{"sunny", "cloudy", "rainy", "partly cloudy", "windy"}
|
||||||
|
temps := []int{65, 72, 58, 80, 45}
|
||||||
|
|
||||||
|
// Use location string to deterministically pick weather
|
||||||
|
idx := len(params.Location) % len(weathers)
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"location": params.Location,
|
||||||
|
"temperature": strconv.Itoa(temps[idx]) + "F",
|
||||||
|
"condition": weathers[idx],
|
||||||
|
"humidity": "45%",
|
||||||
|
"note": "This is mock data for demonstration purposes",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomNumberParams is the parameter struct for the RandomNumber function
|
||||||
|
type RandomNumberParams struct {
|
||||||
|
Min int `json:"min" description:"Minimum value (inclusive)"`
|
||||||
|
Max int `json:"max" description:"Maximum value (inclusive)"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomNumber generates a pseudo-random number (using current time nanoseconds)
|
||||||
|
func RandomNumber(_ *llm.Context, params RandomNumberParams) (any, error) {
|
||||||
|
if params.Min > params.Max {
|
||||||
|
return nil, fmt.Errorf("min cannot be greater than max")
|
||||||
|
}
|
||||||
|
// Simple pseudo-random using time
|
||||||
|
n := time.Now().UnixNano()
|
||||||
|
rangeSize := params.Max - params.Min + 1
|
||||||
|
result := params.Min + int(n%int64(rangeSize))
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDemoToolbox creates a toolbox with demo tools for testing
|
||||||
|
func createDemoToolbox() llm.ToolBox {
|
||||||
|
return llm.NewToolBox(
|
||||||
|
llm.NewFunction("get_time", "Get the current date and time", GetTime),
|
||||||
|
llm.NewFunction("calculate", "Perform basic math operations (add, subtract, multiply, divide, power, sqrt, mod)", Calculate),
|
||||||
|
llm.NewFunction("get_weather", "Get weather information for a location (demo data)", GetWeather),
|
||||||
|
llm.NewFunction("random_number", "Generate a random number between min and max", RandomNumber),
|
||||||
|
)
|
||||||
|
}
|
||||||
435
cmd/llm/update.go
Normal file
435
cmd/llm/update.go
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pendingRequest stores the request being processed for follow-up
|
||||||
|
var pendingRequest llm.Request
|
||||||
|
var pendingResponse llm.ResponseChoice
|
||||||
|
|
||||||
|
// Update handles messages and updates the model
|
||||||
|
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
return m.handleKeyMsg(msg)
|
||||||
|
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.width = msg.Width
|
||||||
|
m.height = msg.Height
|
||||||
|
|
||||||
|
headerHeight := 3
|
||||||
|
footerHeight := 4
|
||||||
|
verticalMargins := headerHeight + footerHeight
|
||||||
|
|
||||||
|
if !m.viewportReady {
|
||||||
|
m.viewport = viewport.New(msg.Width-4, msg.Height-verticalMargins)
|
||||||
|
m.viewport.HighPerformanceRendering = false
|
||||||
|
m.viewportReady = true
|
||||||
|
} else {
|
||||||
|
m.viewport.Width = msg.Width - 4
|
||||||
|
m.viewport.Height = msg.Height - verticalMargins
|
||||||
|
}
|
||||||
|
|
||||||
|
m.input.Width = msg.Width - 6
|
||||||
|
m.apiKeyInput.Width = msg.Width - 6
|
||||||
|
|
||||||
|
m.viewport.SetContent(m.renderMessages())
|
||||||
|
|
||||||
|
case ChatResponseMsg:
|
||||||
|
m.loading = false
|
||||||
|
if msg.Err != nil {
|
||||||
|
m.err = msg.Err
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.Response.Choices) == 0 {
|
||||||
|
m.err = fmt.Errorf("no response choices returned")
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := msg.Response.Choices[0]
|
||||||
|
|
||||||
|
// Check for tool calls
|
||||||
|
if len(choice.Calls) > 0 && m.toolsEnabled {
|
||||||
|
// Store for follow-up
|
||||||
|
pendingResponse = choice
|
||||||
|
|
||||||
|
// Add assistant's response to conversation if there's content
|
||||||
|
if choice.Content != "" {
|
||||||
|
m.addAssistantMessage(choice.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display tool calls
|
||||||
|
for _, call := range choice.Calls {
|
||||||
|
m.addToolCallMessage(call.FunctionCall.Name, call.FunctionCall.Arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.viewport.SetContent(m.renderMessages())
|
||||||
|
m.viewport.GotoBottom()
|
||||||
|
|
||||||
|
// Execute tools
|
||||||
|
m.loading = true
|
||||||
|
return m, executeTools(m.toolbox, pendingRequest, choice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular response - add to conversation and display
|
||||||
|
m.conversation = append(m.conversation, choice)
|
||||||
|
m.addAssistantMessage(choice.Content)
|
||||||
|
|
||||||
|
m.viewport.SetContent(m.renderMessages())
|
||||||
|
m.viewport.GotoBottom()
|
||||||
|
|
||||||
|
case ToolExecutionMsg:
|
||||||
|
if msg.Err != nil {
|
||||||
|
m.loading = false
|
||||||
|
m.err = msg.Err
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display tool results
|
||||||
|
for i, result := range msg.Results {
|
||||||
|
name := pendingResponse.Calls[i].FunctionCall.Name
|
||||||
|
resultStr := fmt.Sprintf("%v", result.Result)
|
||||||
|
if result.Error != nil {
|
||||||
|
resultStr = "Error: " + result.Error.Error()
|
||||||
|
}
|
||||||
|
m.addToolResultMessage(name, resultStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool call responses to conversation
|
||||||
|
for _, result := range msg.Results {
|
||||||
|
m.conversation = append(m.conversation, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the assistant's response to conversation
|
||||||
|
m.conversation = append(m.conversation, pendingResponse)
|
||||||
|
|
||||||
|
m.viewport.SetContent(m.renderMessages())
|
||||||
|
m.viewport.GotoBottom()
|
||||||
|
|
||||||
|
// Send follow-up request
|
||||||
|
followUp := buildFollowUpRequest(&m, pendingRequest, pendingResponse, msg.Results)
|
||||||
|
pendingRequest = followUp
|
||||||
|
return m, sendChatRequest(m.chat, followUp)
|
||||||
|
|
||||||
|
case ImageLoadedMsg:
|
||||||
|
if msg.Err != nil {
|
||||||
|
m.err = msg.Err
|
||||||
|
m.state = m.previousState
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.pendingImages = append(m.pendingImages, msg.Image)
|
||||||
|
m.state = m.previousState
|
||||||
|
m.err = nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Update text input
|
||||||
|
if m.state == StateChat {
|
||||||
|
m.input, cmd = m.input.Update(msg)
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
} else if m.state == StateAPIKeyInput {
|
||||||
|
m.apiKeyInput, cmd = m.apiKeyInput.Update(msg)
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, tea.Batch(cmds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleKeyMsg handles keyboard input
|
||||||
|
func (m Model) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
// Global key handling
|
||||||
|
switch msg.String() {
|
||||||
|
case "ctrl+c":
|
||||||
|
return m, tea.Quit
|
||||||
|
|
||||||
|
case "esc":
|
||||||
|
if m.state != StateChat {
|
||||||
|
m.state = StateChat
|
||||||
|
m.input.Focus()
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
return m, tea.Quit
|
||||||
|
}
|
||||||
|
|
||||||
|
// State-specific key handling
|
||||||
|
switch m.state {
|
||||||
|
case StateChat:
|
||||||
|
return m.handleChatKeys(msg)
|
||||||
|
case StateProviderSelect:
|
||||||
|
return m.handleProviderSelectKeys(msg)
|
||||||
|
case StateModelSelect:
|
||||||
|
return m.handleModelSelectKeys(msg)
|
||||||
|
case StateImageInput:
|
||||||
|
return m.handleImageInputKeys(msg)
|
||||||
|
case StateToolsPanel:
|
||||||
|
return m.handleToolsPanelKeys(msg)
|
||||||
|
case StateSettings:
|
||||||
|
return m.handleSettingsKeys(msg)
|
||||||
|
case StateAPIKeyInput:
|
||||||
|
return m.handleAPIKeyInputKeys(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatKeys handles keys in chat state
|
||||||
|
func (m Model) handleChatKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "enter":
|
||||||
|
if m.loading {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
text := strings.TrimSpace(m.input.Value())
|
||||||
|
if text == "" {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.chat == nil {
|
||||||
|
m.err = fmt.Errorf("no model selected - press Ctrl+P to select a provider")
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build and send request
|
||||||
|
req := buildRequest(&m, text)
|
||||||
|
pendingRequest = req
|
||||||
|
|
||||||
|
// Add user message to display
|
||||||
|
m.addUserMessage(text, m.pendingImages)
|
||||||
|
|
||||||
|
// Clear input and pending images
|
||||||
|
m.input.Reset()
|
||||||
|
m.pendingImages = nil
|
||||||
|
m.err = nil
|
||||||
|
m.loading = true
|
||||||
|
|
||||||
|
m.viewport.SetContent(m.renderMessages())
|
||||||
|
m.viewport.GotoBottom()
|
||||||
|
|
||||||
|
return m, sendChatRequest(m.chat, req)
|
||||||
|
|
||||||
|
case "ctrl+i":
|
||||||
|
m.previousState = StateChat
|
||||||
|
m.state = StateImageInput
|
||||||
|
m.input.SetValue("")
|
||||||
|
m.input.Placeholder = "Enter image path or URL..."
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case "ctrl+t":
|
||||||
|
m.state = StateToolsPanel
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case "ctrl+p":
|
||||||
|
m.state = StateProviderSelect
|
||||||
|
m.listIndex = m.providerIndex
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case "ctrl+m":
|
||||||
|
if m.provider == nil {
|
||||||
|
m.err = fmt.Errorf("select a provider first")
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
m.state = StateModelSelect
|
||||||
|
m.listItems = m.providers[m.providerIndex].Models
|
||||||
|
m.listIndex = m.providers[m.providerIndex].ModelIndex
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case "ctrl+s":
|
||||||
|
m.state = StateSettings
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case "ctrl+n":
|
||||||
|
m.newConversation()
|
||||||
|
m.viewport.SetContent(m.renderMessages())
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case "up", "down", "pgup", "pgdown":
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.viewport, cmd = m.viewport.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
|
||||||
|
default:
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.input, cmd = m.input.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleProviderSelectKeys handles keys in provider selection state
|
||||||
|
func (m Model) handleProviderSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "up", "k":
|
||||||
|
if m.listIndex > 0 {
|
||||||
|
m.listIndex--
|
||||||
|
}
|
||||||
|
case "down", "j":
|
||||||
|
if m.listIndex < len(m.providers)-1 {
|
||||||
|
m.listIndex++
|
||||||
|
}
|
||||||
|
case "enter":
|
||||||
|
p := m.providers[m.listIndex]
|
||||||
|
if !p.HasAPIKey {
|
||||||
|
// Need to get API key
|
||||||
|
m.state = StateAPIKeyInput
|
||||||
|
m.apiKeyInput.Focus()
|
||||||
|
m.apiKeyInput.SetValue("")
|
||||||
|
return m, textinput.Blink
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.selectProvider(m.listIndex)
|
||||||
|
if err != nil {
|
||||||
|
m.err = err
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.state = StateChat
|
||||||
|
m.input.Focus()
|
||||||
|
m.newConversation()
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPIKeyInputKeys handles keys in API key input state
|
||||||
|
func (m Model) handleAPIKeyInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "enter":
|
||||||
|
key := strings.TrimSpace(m.apiKeyInput.Value())
|
||||||
|
if key == "" {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the API key
|
||||||
|
p := m.providers[m.listIndex]
|
||||||
|
m.apiKeys[p.Name] = key
|
||||||
|
m.providers[m.listIndex].HasAPIKey = true
|
||||||
|
|
||||||
|
// Update list items
|
||||||
|
for i, prov := range m.providers {
|
||||||
|
status := " (no key)"
|
||||||
|
if prov.HasAPIKey {
|
||||||
|
status = " (ready)"
|
||||||
|
}
|
||||||
|
m.listItems[i] = prov.Name + status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select the provider
|
||||||
|
err := m.selectProvider(m.listIndex)
|
||||||
|
if err != nil {
|
||||||
|
m.err = err
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.state = StateChat
|
||||||
|
m.input.Focus()
|
||||||
|
m.newConversation()
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.apiKeyInput, cmd = m.apiKeyInput.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleModelSelectKeys handles keys in model selection state
|
||||||
|
func (m Model) handleModelSelectKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "up", "k":
|
||||||
|
if m.listIndex > 0 {
|
||||||
|
m.listIndex--
|
||||||
|
}
|
||||||
|
case "down", "j":
|
||||||
|
if m.listIndex < len(m.listItems)-1 {
|
||||||
|
m.listIndex++
|
||||||
|
}
|
||||||
|
case "enter":
|
||||||
|
err := m.selectModel(m.listIndex)
|
||||||
|
if err != nil {
|
||||||
|
m.err = err
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
m.state = StateChat
|
||||||
|
m.input.Focus()
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleImageInputKeys handles keys in image input state
|
||||||
|
func (m Model) handleImageInputKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "enter":
|
||||||
|
input := strings.TrimSpace(m.input.Value())
|
||||||
|
if input == "" {
|
||||||
|
m.state = m.previousState
|
||||||
|
m.input.Placeholder = "Type your message..."
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.input.Placeholder = "Type your message..."
|
||||||
|
|
||||||
|
// Determine input type and load
|
||||||
|
if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
|
||||||
|
return m, loadImageFromURL(input)
|
||||||
|
} else if strings.HasPrefix(input, "data:") || len(input) > 100 && !strings.Contains(input, "/") && !strings.Contains(input, "\\") {
|
||||||
|
return m, loadImageFromBase64(input)
|
||||||
|
} else {
|
||||||
|
return m, loadImageFromPath(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.input, cmd = m.input.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleToolsPanelKeys handles keys in tools panel state
|
||||||
|
func (m Model) handleToolsPanelKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "t":
|
||||||
|
m.toolsEnabled = !m.toolsEnabled
|
||||||
|
case "enter", "q":
|
||||||
|
m.state = StateChat
|
||||||
|
m.input.Focus()
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSettingsKeys handles keys in settings state
|
||||||
|
func (m Model) handleSettingsKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "1":
|
||||||
|
// Set temperature to nil (default)
|
||||||
|
m.temperature = nil
|
||||||
|
case "2":
|
||||||
|
t := 0.0
|
||||||
|
m.temperature = &t
|
||||||
|
case "3":
|
||||||
|
t := 0.5
|
||||||
|
m.temperature = &t
|
||||||
|
case "4":
|
||||||
|
t := 0.7
|
||||||
|
m.temperature = &t
|
||||||
|
case "5":
|
||||||
|
t := 1.0
|
||||||
|
m.temperature = &t
|
||||||
|
case "enter", "q":
|
||||||
|
m.state = StateChat
|
||||||
|
m.input.Focus()
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
296
cmd/llm/view.go
Normal file
296
cmd/llm/view.go
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// View renders the current state
|
||||||
|
func (m Model) View() string {
|
||||||
|
switch m.state {
|
||||||
|
case StateProviderSelect:
|
||||||
|
return m.renderProviderSelect()
|
||||||
|
case StateModelSelect:
|
||||||
|
return m.renderModelSelect()
|
||||||
|
case StateImageInput:
|
||||||
|
return m.renderImageInput()
|
||||||
|
case StateToolsPanel:
|
||||||
|
return m.renderToolsPanel()
|
||||||
|
case StateSettings:
|
||||||
|
return m.renderSettings()
|
||||||
|
case StateAPIKeyInput:
|
||||||
|
return m.renderAPIKeyInput()
|
||||||
|
default:
|
||||||
|
return m.renderChat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderChat renders the main chat view
|
||||||
|
func (m Model) renderChat() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Header
|
||||||
|
provider := m.providerName
|
||||||
|
if provider == "" {
|
||||||
|
provider = "None"
|
||||||
|
}
|
||||||
|
model := m.modelName
|
||||||
|
if model == "" {
|
||||||
|
model = "None"
|
||||||
|
}
|
||||||
|
|
||||||
|
header := headerStyle.Render(fmt.Sprintf("go-llm CLI %s",
|
||||||
|
providerBadgeStyle.Render(fmt.Sprintf("%s/%s", provider, model))))
|
||||||
|
|
||||||
|
b.WriteString(header)
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// Messages viewport
|
||||||
|
if m.viewportReady {
|
||||||
|
b.WriteString(m.viewport.View())
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image indicator
|
||||||
|
if len(m.pendingImages) > 0 {
|
||||||
|
b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s) attached]", len(m.pendingImages))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error
|
||||||
|
if m.err != nil {
|
||||||
|
b.WriteString(errorStyle.Render(" Error: " + m.err.Error()))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loading
|
||||||
|
if m.loading {
|
||||||
|
b.WriteString(loadingStyle.Render(" Thinking..."))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Input
|
||||||
|
inputBox := inputStyle.Render(m.input.View())
|
||||||
|
b.WriteString(inputBox)
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// Help
|
||||||
|
help := inputHelpStyle.Render("Enter: send | Ctrl+I: image | Ctrl+T: tools | Ctrl+P: provider | Ctrl+M: model | Ctrl+S: settings | Ctrl+N: new | Esc: quit")
|
||||||
|
b.WriteString(help)
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderMessages renders all messages for the viewport
|
||||||
|
func (m Model) renderMessages() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
if len(m.messages) == 0 {
|
||||||
|
b.WriteString(systemMsgStyle.Render("[System] " + m.systemPrompt))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(lipgloss.NewStyle().Foreground(mutedColor).Render("Start a conversation by typing a message below."))
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(systemMsgStyle.Render("[System] " + m.systemPrompt))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
for _, msg := range m.messages {
|
||||||
|
var content string
|
||||||
|
var style lipgloss.Style
|
||||||
|
|
||||||
|
switch msg.Role {
|
||||||
|
case llm.RoleUser:
|
||||||
|
style = userMsgStyle
|
||||||
|
label := roleLabelStyle.Foreground(secondaryColor).Render("[User]")
|
||||||
|
content = label + " " + msg.Content
|
||||||
|
if msg.Images > 0 {
|
||||||
|
content += imageIndicatorStyle.Render(fmt.Sprintf(" [%d image(s)]", msg.Images))
|
||||||
|
}
|
||||||
|
case llm.RoleAssistant:
|
||||||
|
style = assistantMsgStyle
|
||||||
|
label := roleLabelStyle.Foreground(lipgloss.Color("255")).Render("[Assistant]")
|
||||||
|
content = label + " " + msg.Content
|
||||||
|
case llm.Role("tool_call"):
|
||||||
|
style = toolCallStyle
|
||||||
|
content = " -> Calling: " + msg.Content
|
||||||
|
case llm.Role("tool_result"):
|
||||||
|
style = toolResultStyle
|
||||||
|
content = " <- Result: " + msg.Content
|
||||||
|
default:
|
||||||
|
style = assistantMsgStyle
|
||||||
|
content = msg.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(style.Render(content))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderProviderSelect renders the provider selection view
|
||||||
|
func (m Model) renderProviderSelect() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString(headerStyle.Render("Select Provider"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
for i, item := range m.listItems {
|
||||||
|
cursor := " "
|
||||||
|
style := normalItemStyle
|
||||||
|
if i == m.listIndex {
|
||||||
|
cursor = "> "
|
||||||
|
style = selectedItemStyle
|
||||||
|
}
|
||||||
|
b.WriteString(style.Render(cursor + item))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(helpStyle.Render("Use arrow keys or j/k to navigate, Enter to select, Esc to cancel"))
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderAPIKeyInput renders the API key input view
|
||||||
|
func (m Model) renderAPIKeyInput() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
provider := m.providers[m.listIndex]
|
||||||
|
|
||||||
|
b.WriteString(headerStyle.Render(fmt.Sprintf("Enter API Key for %s", provider.Name)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf("Environment variable: %s\n\n", provider.EnvVar))
|
||||||
|
b.WriteString("Enter your API key below (it will be hidden):\n\n")
|
||||||
|
|
||||||
|
inputBox := inputStyle.Render(m.apiKeyInput.View())
|
||||||
|
b.WriteString(inputBox)
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(helpStyle.Render("Enter to confirm, Esc to cancel"))
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderModelSelect renders the model selection view
|
||||||
|
func (m Model) renderModelSelect() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString(headerStyle.Render(fmt.Sprintf("Select Model (%s)", m.providerName)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
for i, item := range m.listItems {
|
||||||
|
cursor := " "
|
||||||
|
style := normalItemStyle
|
||||||
|
if i == m.listIndex {
|
||||||
|
cursor = "> "
|
||||||
|
style = selectedItemStyle
|
||||||
|
}
|
||||||
|
if item == m.modelName {
|
||||||
|
item += " (current)"
|
||||||
|
}
|
||||||
|
b.WriteString(style.Render(cursor + item))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(helpStyle.Render("Use arrow keys or j/k to navigate, Enter to select, Esc to cancel"))
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderImageInput renders the image input view
|
||||||
|
func (m Model) renderImageInput() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString(headerStyle.Render("Add Image"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString("Enter an image source:\n")
|
||||||
|
b.WriteString(" - File path (e.g., /path/to/image.png)\n")
|
||||||
|
b.WriteString(" - URL (e.g., https://example.com/image.jpg)\n")
|
||||||
|
b.WriteString(" - Base64 data or data URL\n\n")
|
||||||
|
|
||||||
|
if len(m.pendingImages) > 0 {
|
||||||
|
b.WriteString(imageIndicatorStyle.Render(fmt.Sprintf("Currently attached: %d image(s)\n\n", len(m.pendingImages))))
|
||||||
|
}
|
||||||
|
|
||||||
|
inputBox := inputStyle.Render(m.input.View())
|
||||||
|
b.WriteString(inputBox)
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(helpStyle.Render("Enter to add image, Esc to cancel"))
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderToolsPanel renders the tools panel
|
||||||
|
func (m Model) renderToolsPanel() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString(headerStyle.Render("Tools / Function Calling"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
status := "DISABLED"
|
||||||
|
statusStyle := errorStyle
|
||||||
|
if m.toolsEnabled {
|
||||||
|
status = "ENABLED"
|
||||||
|
statusStyle = lipgloss.NewStyle().Foreground(successColor).Bold(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(settingLabelStyle.Render("Tools Status:"))
|
||||||
|
b.WriteString(statusStyle.Render(status))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString("Available tools:\n")
|
||||||
|
for _, fn := range m.toolbox.Functions() {
|
||||||
|
b.WriteString(fmt.Sprintf(" - %s: %s\n", selectedItemStyle.Render(fn.Name), fn.Description))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(helpStyle.Render("Press 't' to toggle tools, Enter or 'q' to close"))
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderSettings renders the settings view
|
||||||
|
func (m Model) renderSettings() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString(headerStyle.Render("Settings"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Temperature
|
||||||
|
tempStr := "default"
|
||||||
|
if m.temperature != nil {
|
||||||
|
tempStr = fmt.Sprintf("%.1f", *m.temperature)
|
||||||
|
}
|
||||||
|
b.WriteString(settingLabelStyle.Render("Temperature:"))
|
||||||
|
b.WriteString(settingValueStyle.Render(tempStr))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString("Press a key to set temperature:\n")
|
||||||
|
b.WriteString(" 1 - Default (model decides)\n")
|
||||||
|
b.WriteString(" 2 - 0.0 (deterministic)\n")
|
||||||
|
b.WriteString(" 3 - 0.5 (balanced)\n")
|
||||||
|
b.WriteString(" 4 - 0.7 (creative)\n")
|
||||||
|
b.WriteString(" 5 - 1.0 (very creative)\n")
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// System prompt
|
||||||
|
b.WriteString(settingLabelStyle.Render("System Prompt:"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(settingValueStyle.Render(" " + m.systemPrompt))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(helpStyle.Render("Enter or 'q' to close"))
|
||||||
|
|
||||||
|
return appStyle.Render(b.String())
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|||||||
575
docs/sandbox-setup.md
Normal file
575
docs/sandbox-setup.md
Normal file
@@ -0,0 +1,575 @@
|
|||||||
|
# Sandbox Setup & Hardening Guide
|
||||||
|
|
||||||
|
Complete guide for setting up a Proxmox VE host to run isolated LXC sandbox containers for the go-llm sandbox package.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
1. [Prerequisites](#1-prerequisites)
|
||||||
|
2. [Proxmox Host Preparation](#2-proxmox-host-preparation)
|
||||||
|
3. [Network Setup](#3-network-setup)
|
||||||
|
4. [LXC Template Creation](#4-lxc-template-creation)
|
||||||
|
5. [SSH Key Setup](#5-ssh-key-setup)
|
||||||
|
6. [Configuration](#6-configuration)
|
||||||
|
7. [Hardening Checklist](#7-hardening-checklist)
|
||||||
|
8. [Monitoring & Maintenance](#8-monitoring--maintenance)
|
||||||
|
9. [Troubleshooting](#9-troubleshooting)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Prerequisites
|
||||||
|
|
||||||
|
### Hardware/VM Requirements
|
||||||
|
|
||||||
|
| Resource | Minimum | Recommended |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| CPU | 4 cores | 8+ cores |
|
||||||
|
| RAM | 8 GB | 16+ GB |
|
||||||
|
| Storage | 100 GB SSD | 250+ GB SSD |
|
||||||
|
| Network | 1 NIC | 2 NICs (mgmt + sandbox) |
|
||||||
|
|
||||||
|
### Software
|
||||||
|
|
||||||
|
- Proxmox VE 8.x ([installation guide](https://pve.proxmox.com/wiki/Installation))
|
||||||
|
- During install, configure the management interface on `vmbr0`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Proxmox Host Preparation
|
||||||
|
|
||||||
|
### Create Resource Pool
|
||||||
|
|
||||||
|
Scope sandbox containers to a dedicated resource pool to limit API token access:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pvesh create /pools --poolid sandbox-pool
|
||||||
|
```
|
||||||
|
|
||||||
|
### Create API User and Token
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Create dedicated user
|
||||||
|
pveum useradd mort-sandbox@pve
|
||||||
|
|
||||||
|
# Create role with minimum required permissions
|
||||||
|
pveum roleadd SandboxAdmin -privs "VM.Allocate,VM.Clone,VM.Audit,VM.PowerMgmt,VM.Console,Datastore.AllocateSpace,Datastore.Audit"
|
||||||
|
|
||||||
|
# Grant role on the sandbox pool only
|
||||||
|
pveum aclmod /pool/sandbox-pool -user mort-sandbox@pve -role SandboxAdmin
|
||||||
|
|
||||||
|
# Grant access to the template storage
|
||||||
|
pveum aclmod /storage/local -user mort-sandbox@pve -role PVEDatastoreUser
|
||||||
|
|
||||||
|
# Create API token (privsep=0 means token inherits user's permissions)
|
||||||
|
pveum user token add mort-sandbox@pve sandbox-token --privsep=0
|
||||||
|
```
|
||||||
|
|
||||||
|
Save the output — it contains the token secret:
|
||||||
|
```
|
||||||
|
┌──────────┬──────────────────────────────────────────┐
|
||||||
|
│ key │ value │
|
||||||
|
╞══════════╪══════════════════════════════════════════╡
|
||||||
|
│ full-tokenid │ mort-sandbox@pve!sandbox-token │
|
||||||
|
│ value │ xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx │
|
||||||
|
└──────────┴──────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
Store the secret securely (environment variable, secret manager, etc.). Never commit it to source control.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Network Setup
|
||||||
|
|
||||||
|
### 3.1 Create Isolated Bridge
|
||||||
|
|
||||||
|
Add to `/etc/network/interfaces` on the Proxmox host:
|
||||||
|
|
||||||
|
```
|
||||||
|
auto vmbr1
|
||||||
|
iface vmbr1 inet static
|
||||||
|
address 10.99.0.1/16
|
||||||
|
bridge-ports none
|
||||||
|
bridge-stp off
|
||||||
|
bridge-fd 0
|
||||||
|
post-up echo 1 > /proc/sys/net/ipv4/ip_forward
|
||||||
|
# NAT for optional internet access (controlled per-container by nftables)
|
||||||
|
post-up nft add table nat 2>/dev/null; true
|
||||||
|
post-up nft add chain nat postrouting { type nat hook postrouting priority 100 \; } 2>/dev/null; true
|
||||||
|
post-up nft add rule nat postrouting oifname "vmbr0" ip saddr 10.99.0.0/16 masquerade 2>/dev/null; true
|
||||||
|
```
|
||||||
|
|
||||||
|
Apply the configuration:
|
||||||
|
```bash
|
||||||
|
ifreload -a
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 Install and Configure DHCP
|
||||||
|
|
||||||
|
```bash
|
||||||
|
apt-get install -y dnsmasq
|
||||||
|
```
|
||||||
|
|
||||||
|
Create `/etc/dnsmasq.d/sandbox.conf`:
|
||||||
|
```
|
||||||
|
interface=vmbr1
|
||||||
|
bind-interfaces
|
||||||
|
dhcp-range=10.99.1.1,10.99.254.254,255.255.0.0,1h
|
||||||
|
dhcp-option=option:router,10.99.0.1
|
||||||
|
dhcp-option=option:dns-server,1.1.1.1,8.8.8.8
|
||||||
|
```
|
||||||
|
|
||||||
|
Restart dnsmasq:
|
||||||
|
```bash
|
||||||
|
systemctl restart dnsmasq
|
||||||
|
systemctl enable dnsmasq
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.3 Configure nftables Firewall
|
||||||
|
|
||||||
|
Create `/etc/nftables.conf`:
|
||||||
|
|
||||||
|
```nft
|
||||||
|
#!/usr/sbin/nft -f
|
||||||
|
|
||||||
|
flush ruleset
|
||||||
|
|
||||||
|
table inet sandbox {
|
||||||
|
# Dynamic set of container IPs allowed internet access.
|
||||||
|
# Populated/cleared by the sandbox manager via the Proxmox API.
|
||||||
|
set internet_allowed {
|
||||||
|
type ipv4_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
chain forward {
|
||||||
|
type filter hook forward priority 0; policy drop;
|
||||||
|
|
||||||
|
# Allow established/related connections
|
||||||
|
ct state established,related accept
|
||||||
|
|
||||||
|
# Allow inter-bridge traffic (host ↔ containers via vmbr1)
|
||||||
|
iifname "vmbr1" oifname "vmbr1" accept
|
||||||
|
|
||||||
|
# Allow DNS for all containers (needed for apt)
|
||||||
|
ip saddr 10.99.0.0/16 udp dport 53 accept
|
||||||
|
ip saddr 10.99.0.0/16 tcp dport 53 accept
|
||||||
|
|
||||||
|
# Allow HTTP/HTTPS only for containers in the internet_allowed set
|
||||||
|
ip saddr @internet_allowed tcp dport { 80, 443 } accept
|
||||||
|
|
||||||
|
# Rate limit: max 50 new connections per second per container
|
||||||
|
ip saddr 10.99.0.0/16 ct state new limit rate over 50/second drop
|
||||||
|
|
||||||
|
# Block everything else from containers
|
||||||
|
ip saddr 10.99.0.0/16 drop
|
||||||
|
|
||||||
|
# Allow host → containers (for SSH from the application)
|
||||||
|
ip daddr 10.99.0.0/16 accept
|
||||||
|
}
|
||||||
|
|
||||||
|
chain input {
|
||||||
|
type filter hook input priority 0; policy accept;
|
||||||
|
|
||||||
|
# Block containers from accessing Proxmox management ports
|
||||||
|
# (only SSH is allowed for the sandbox manager)
|
||||||
|
iifname "vmbr1" ip daddr 10.99.0.1 tcp dport != 22 drop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# NAT table for optional internet access
|
||||||
|
table nat {
|
||||||
|
chain postrouting {
|
||||||
|
type nat hook postrouting priority 100;
|
||||||
|
oifname "vmbr0" ip saddr 10.99.0.0/16 masquerade
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Apply and persist:
|
||||||
|
```bash
|
||||||
|
nft -f /etc/nftables.conf
|
||||||
|
systemctl enable nftables
|
||||||
|
```
|
||||||
|
|
||||||
|
Verify:
|
||||||
|
```bash
|
||||||
|
nft list ruleset
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.4 Test Network Isolation
|
||||||
|
|
||||||
|
From a test container on `vmbr1`:
|
||||||
|
```bash
|
||||||
|
# Should work: DNS resolution
|
||||||
|
dig google.com
|
||||||
|
|
||||||
|
# Should be blocked: HTTP (not in internet_allowed set)
|
||||||
|
curl -s --connect-timeout 5 https://google.com && echo "FAIL: should be blocked" || echo "OK: blocked"
|
||||||
|
|
||||||
|
# Should be blocked: access to LAN
|
||||||
|
ping -c 1 -W 2 192.168.1.1 && echo "FAIL: LAN reachable" || echo "OK: LAN blocked"
|
||||||
|
|
||||||
|
# Should be blocked: access to Proxmox management
|
||||||
|
curl -s --connect-timeout 5 https://10.99.0.1:8006 && echo "FAIL: Proxmox reachable" || echo "OK: Proxmox blocked"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. LXC Template Creation
|
||||||
|
|
||||||
|
### 4.1 Download Base Image
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pveam update
|
||||||
|
pveam download local ubuntu-24.04-standard_24.04-1_amd64.tar.zst
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 Create Template Container
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pct create 9000 local:vztmpl/ubuntu-24.04-standard_24.04-1_amd64.tar.zst \
|
||||||
|
--hostname sandbox-template \
|
||||||
|
--memory 1024 \
|
||||||
|
--swap 0 \
|
||||||
|
--cores 1 \
|
||||||
|
--rootfs local-lvm:8 \
|
||||||
|
--net0 name=eth0,bridge=vmbr1,ip=dhcp \
|
||||||
|
--unprivileged 1 \
|
||||||
|
--features nesting=0 \
|
||||||
|
--ostype ubuntu \
|
||||||
|
--ssh-public-keys /root/.ssh/mort_sandbox.pub \
|
||||||
|
--pool sandbox-pool \
|
||||||
|
--start 0
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 Install Base Packages
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pct start 9000
|
||||||
|
pct exec 9000 -- bash -c '
|
||||||
|
apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
python3 python3-pip python3-venv \
|
||||||
|
nodejs npm \
|
||||||
|
git curl wget jq \
|
||||||
|
vim nano \
|
||||||
|
htop tree \
|
||||||
|
ca-certificates \
|
||||||
|
openssh-server \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.4 Create Sandbox User
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pct exec 9000 -- bash -c '
|
||||||
|
# Create unprivileged sandbox user with sudo
|
||||||
|
useradd -m -s /bin/bash -G sudo sandbox
|
||||||
|
echo "sandbox ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/sandbox
|
||||||
|
|
||||||
|
# Set up SSH access
|
||||||
|
mkdir -p /home/sandbox/.ssh
|
||||||
|
cp /root/.ssh/authorized_keys /home/sandbox/.ssh/
|
||||||
|
chown -R sandbox:sandbox /home/sandbox/.ssh
|
||||||
|
chmod 700 /home/sandbox/.ssh
|
||||||
|
chmod 600 /home/sandbox/.ssh/authorized_keys
|
||||||
|
|
||||||
|
# Create uploads directory
|
||||||
|
mkdir -p /home/sandbox/uploads
|
||||||
|
chown sandbox:sandbox /home/sandbox/uploads
|
||||||
|
|
||||||
|
# Enable SSH
|
||||||
|
systemctl enable ssh
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.5 Security Hardening
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pct exec 9000 -- bash -c '
|
||||||
|
# Process limits (prevent fork bombs)
|
||||||
|
echo "* soft nproc 256" >> /etc/security/limits.conf
|
||||||
|
echo "* hard nproc 512" >> /etc/security/limits.conf
|
||||||
|
|
||||||
|
# Disable core dumps
|
||||||
|
echo "* hard core 0" >> /etc/security/limits.conf
|
||||||
|
|
||||||
|
# Disable unnecessary services
|
||||||
|
systemctl disable systemd-resolved 2>/dev/null || true
|
||||||
|
systemctl disable snapd 2>/dev/null || true
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.6 Convert to Template
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pct stop 9000
|
||||||
|
pct template 9000
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.7 Verify Template
|
||||||
|
|
||||||
|
Clone and test manually:
|
||||||
|
```bash
|
||||||
|
pct clone 9000 9999 --hostname test-sandbox --full
|
||||||
|
pct start 9999
|
||||||
|
# Wait for DHCP, then SSH in
|
||||||
|
ssh sandbox@<container-ip>
|
||||||
|
# Run some commands, verify packages installed
|
||||||
|
sudo apt-get update
|
||||||
|
python3 --version
|
||||||
|
node --version
|
||||||
|
# Clean up
|
||||||
|
exit
|
||||||
|
pct stop 9999
|
||||||
|
pct destroy 9999
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. SSH Key Setup
|
||||||
|
|
||||||
|
### 5.1 Generate Key Pair
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh-keygen -t ed25519 -f /etc/mort/sandbox_key -N "" -C "mort-sandbox"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 Install Public Key in Template
|
||||||
|
|
||||||
|
This was done in step 4.2 with `--ssh-public-keys`. If you need to update it:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pct start 9000 # Only if template — you'll need to untemplate first
|
||||||
|
# Copy key
|
||||||
|
cat /etc/mort/sandbox_key.pub | pct exec 9000 -- tee /home/sandbox/.ssh/authorized_keys
|
||||||
|
pct exec 9000 -- chown sandbox:sandbox /home/sandbox/.ssh/authorized_keys
|
||||||
|
pct exec 9000 -- chmod 600 /home/sandbox/.ssh/authorized_keys
|
||||||
|
pct stop 9000
|
||||||
|
pct template 9000
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 Set Permissions
|
||||||
|
|
||||||
|
```bash
|
||||||
|
chmod 600 /etc/mort/sandbox_key
|
||||||
|
chmod 644 /etc/mort/sandbox_key.pub
|
||||||
|
# If running as a specific user:
|
||||||
|
chown mort:mort /etc/mort/sandbox_key /etc/mort/sandbox_key.pub
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Configuration
|
||||||
|
|
||||||
|
### Go Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
signer, _ := sandbox.LoadSSHKey("/etc/mort/sandbox_key")
|
||||||
|
|
||||||
|
mgr, _ := sandbox.NewManager(sandbox.Config{
|
||||||
|
Proxmox: sandbox.ProxmoxConfig{
|
||||||
|
BaseURL: "https://proxmox.local:8006",
|
||||||
|
TokenID: "mort-sandbox@pve!sandbox-token",
|
||||||
|
Secret: os.Getenv("SANDBOX_PROXMOX_SECRET"),
|
||||||
|
Node: "pve",
|
||||||
|
TemplateID: 9000,
|
||||||
|
Pool: "sandbox-pool",
|
||||||
|
Bridge: "vmbr1",
|
||||||
|
InsecureSkipVerify: true, // Only for self-signed certs
|
||||||
|
},
|
||||||
|
SSH: sandbox.SSHConfig{
|
||||||
|
Signer: signer,
|
||||||
|
User: "sandbox", // default
|
||||||
|
ConnectTimeout: 10 * time.Second, // default
|
||||||
|
CommandTimeout: 60 * time.Second, // default
|
||||||
|
},
|
||||||
|
Defaults: sandbox.ContainerConfig{
|
||||||
|
CPUs: 1,
|
||||||
|
MemoryMB: 1024,
|
||||||
|
DiskGB: 8,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
| Variable | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `SANDBOX_PROXMOX_SECRET` | Proxmox API token secret |
|
||||||
|
| `SANDBOX_SSH_KEY_PATH` | Path to SSH private key (alternative to config) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Hardening Checklist
|
||||||
|
|
||||||
|
Run through this checklist after setup:
|
||||||
|
|
||||||
|
### Container Isolation
|
||||||
|
- [ ] Containers are unprivileged (verify UID mapping in `/etc/pve/lxc/<id>.conf`)
|
||||||
|
- [ ] Nesting is disabled (`features: nesting=0`)
|
||||||
|
- [ ] Swap is disabled on containers (`swap: 0`)
|
||||||
|
- [ ] Resource pool scoping: API token can only touch `sandbox-pool`
|
||||||
|
|
||||||
|
### Network Isolation
|
||||||
|
- [ ] `vmbr1` has no physical ports (`bridge-ports none`)
|
||||||
|
- [ ] nftables rules loaded: `nft list ruleset` shows sandbox table
|
||||||
|
- [ ] nftables persists across reboots: `systemctl is-enabled nftables`
|
||||||
|
- [ ] Default-deny outbound for containers
|
||||||
|
- [ ] DNS (port 53) allowed for all containers
|
||||||
|
- [ ] HTTP/HTTPS only for containers in `internet_allowed` set
|
||||||
|
- [ ] Rate limiting active (50 conn/sec)
|
||||||
|
- [ ] Containers cannot reach Proxmox management (port 8006 blocked)
|
||||||
|
|
||||||
|
### Security Profiles
|
||||||
|
- [ ] AppArmor profile active: `lxc-container-default-cgns`
|
||||||
|
- [ ] Process limits in `/etc/security/limits.conf` (nproc 256/512)
|
||||||
|
- [ ] Core dumps disabled
|
||||||
|
- [ ] Capability drops verified in container config
|
||||||
|
|
||||||
|
### Functional Tests
|
||||||
|
- [ ] **Fork bomb test**: run `:(){ :|:& };:` in container → PID limit fires, container survives
|
||||||
|
- [ ] **OOM test**: allocate >1GB memory → container OOM-killed, host unaffected
|
||||||
|
- [ ] **Network scan test**: `nmap` from container → blocked by nftables
|
||||||
|
- [ ] **Container escape test**: attempt to mount host filesystem → denied
|
||||||
|
- [ ] **LAN access test**: ping LAN hosts → blocked
|
||||||
|
- [ ] **Cross-container test**: ping other sandbox containers → blocked
|
||||||
|
- [ ] **Internet access test**: HTTP without being in `internet_allowed` → blocked
|
||||||
|
- [ ] **Internet access test**: add to `internet_allowed` → HTTP works
|
||||||
|
- [ ] **Cleanup test**: destroy container → verify no orphan volumes
|
||||||
|
|
||||||
|
### Operational Tests
|
||||||
|
- [ ] Clone template → container starts → SSH connects → commands work
|
||||||
|
- [ ] File upload/download via SFTP works
|
||||||
|
- [ ] Container destroy removes all resources
|
||||||
|
- [ ] Orphan cleanup: kill application mid-session, restart, verify cleanup
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Monitoring & Maintenance
|
||||||
|
|
||||||
|
### Log Rotation
|
||||||
|
|
||||||
|
Sandbox session logs should be rotated to prevent disk exhaustion. If using slog to a file:
|
||||||
|
|
||||||
|
```
|
||||||
|
# /etc/logrotate.d/sandbox
|
||||||
|
/var/log/sandbox/*.log {
|
||||||
|
daily
|
||||||
|
rotate 14
|
||||||
|
compress
|
||||||
|
delaycompress
|
||||||
|
missingok
|
||||||
|
notifempty
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Storage Cleanup
|
||||||
|
|
||||||
|
Verify destroyed containers don't leave orphan volumes:
|
||||||
|
```bash
|
||||||
|
# List all LVM volumes in the sandbox storage
|
||||||
|
lvs | grep sandbox
|
||||||
|
|
||||||
|
# Compare with running containers
|
||||||
|
pct list | grep sandbox
|
||||||
|
```
|
||||||
|
|
||||||
|
### Template Updates
|
||||||
|
|
||||||
|
Periodically update the template with latest packages:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Un-template (creates a regular container from template)
|
||||||
|
# Note: you can't un-template directly; clone then replace
|
||||||
|
pct clone 9000 9001 --hostname template-update --full
|
||||||
|
pct start 9001
|
||||||
|
pct exec 9001 -- bash -c 'apt-get update && apt-get upgrade -y && apt-get clean'
|
||||||
|
pct stop 9001
|
||||||
|
|
||||||
|
# Destroy old template and create new one
|
||||||
|
pct destroy 9000
|
||||||
|
# Rename 9001 → 9000 (or update your config to use the new ID)
|
||||||
|
pct template 9001
|
||||||
|
```
|
||||||
|
|
||||||
|
### Proxmox Host Updates
|
||||||
|
|
||||||
|
```bash
|
||||||
|
apt-get update && apt-get dist-upgrade -y
|
||||||
|
# Reboot if kernel was updated
|
||||||
|
# Verify nftables rules are still loaded after reboot
|
||||||
|
nft list ruleset
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Troubleshooting
|
||||||
|
|
||||||
|
### Container won't start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check task log
|
||||||
|
pct start <id>
|
||||||
|
# If error, check:
|
||||||
|
journalctl -u pve-container@<id> -n 50
|
||||||
|
|
||||||
|
# Common issues:
|
||||||
|
# - Storage full: check `df -h` and `lvs`
|
||||||
|
# - UID mapping issues: verify /etc/subuid and /etc/subgid
|
||||||
|
```
|
||||||
|
|
||||||
|
### SSH connection refused
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Verify container is running
|
||||||
|
pct status <id>
|
||||||
|
|
||||||
|
# Check if SSH is running inside container
|
||||||
|
pct exec <id> -- systemctl status ssh
|
||||||
|
|
||||||
|
# Verify IP assignment
|
||||||
|
pct exec <id> -- ip addr show eth0
|
||||||
|
|
||||||
|
# Check DHCP leases
|
||||||
|
cat /var/lib/misc/dnsmasq.leases
|
||||||
|
```
|
||||||
|
|
||||||
|
### Container has no internet (when it should)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Verify container IP is in the internet_allowed set
|
||||||
|
nft list set inet sandbox internet_allowed
|
||||||
|
|
||||||
|
# Manually add for testing
|
||||||
|
nft add element inet sandbox internet_allowed { 10.99.1.5 }
|
||||||
|
|
||||||
|
# Verify NAT is working
|
||||||
|
nft list table nat
|
||||||
|
|
||||||
|
# Check if IP forwarding is enabled
|
||||||
|
cat /proc/sys/net/ipv4/ip_forward # Should be 1
|
||||||
|
```
|
||||||
|
|
||||||
|
### nftables rules lost after reboot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Verify nftables is enabled
|
||||||
|
systemctl is-enabled nftables
|
||||||
|
|
||||||
|
# If rules are missing, reload
|
||||||
|
nft -f /etc/nftables.conf
|
||||||
|
|
||||||
|
# Make sure the config file is correct
|
||||||
|
nft -c -f /etc/nftables.conf # Check syntax without applying
|
||||||
|
```
|
||||||
|
|
||||||
|
### Orphaned containers
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List all containers in the sandbox pool
|
||||||
|
pvesh get /pools/sandbox-pool --output-format json | jq '.members[] | select(.type == "lxc")'
|
||||||
|
|
||||||
|
# Destroy orphans manually
|
||||||
|
pct stop <id> && pct destroy <id> --force --purge
|
||||||
|
```
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|||||||
87
go.mod
87
go.mod
@@ -1,48 +1,67 @@
|
|||||||
module gitea.stevedudenhoeffer.com/steve/go-llm
|
module gitea.stevedudenhoeffer.com/steve/go-llm
|
||||||
|
|
||||||
go 1.23.1
|
go 1.24.0
|
||||||
|
|
||||||
|
toolchain go1.24.2
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/generative-ai-go v0.19.0
|
github.com/charmbracelet/bubbles v0.21.0
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.15.0
|
github.com/charmbracelet/bubbletea v1.3.10
|
||||||
github.com/openai/openai-go v0.1.0-beta.9
|
github.com/charmbracelet/lipgloss v1.1.0
|
||||||
golang.org/x/image v0.29.0
|
github.com/joho/godotenv v1.5.1
|
||||||
google.golang.org/api v0.228.0
|
github.com/liushuangls/go-anthropic/v2 v2.17.0
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||||
|
github.com/openai/openai-go v1.12.0
|
||||||
|
golang.org/x/image v0.35.0
|
||||||
|
google.golang.org/genai v1.43.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cloud.google.com/go v0.120.0 // indirect
|
cloud.google.com/go v0.123.0 // indirect
|
||||||
cloud.google.com/go/ai v0.10.1 // indirect
|
cloud.google.com/go/auth v0.18.1 // indirect
|
||||||
cloud.google.com/go/auth v0.15.0 // indirect
|
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
github.com/atotto/clipboard v0.1.4 // indirect
|
||||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
cloud.google.com/go/longrunning v0.6.6 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||||
|
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||||
|
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/go-logr/logr v1.4.2 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
|
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||||
github.com/google/s2a-go v0.1.9 // indirect
|
github.com/google/s2a-go v0.1.9 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
|
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||||
|
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||||
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/tidwall/gjson v1.18.0 // indirect
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.2.0 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/tidwall/sjson v1.2.5 // indirect
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
go.opentelemetry.io/otel v1.35.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
go.opentelemetry.io/otel v1.39.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
go.opentelemetry.io/otel/metric v1.39.0 // indirect
|
||||||
golang.org/x/crypto v0.37.0 // indirect
|
go.opentelemetry.io/otel/trace v1.39.0 // indirect
|
||||||
golang.org/x/net v0.39.0 // indirect
|
golang.org/x/crypto v0.47.0 // indirect
|
||||||
golang.org/x/oauth2 v0.29.0 // indirect
|
golang.org/x/net v0.49.0 // indirect
|
||||||
golang.org/x/sync v0.16.0 // indirect
|
golang.org/x/oauth2 v0.32.0 // indirect
|
||||||
golang.org/x/sys v0.32.0 // indirect
|
golang.org/x/sys v0.40.0 // indirect
|
||||||
golang.org/x/text v0.27.0 // indirect
|
golang.org/x/text v0.33.0 // indirect
|
||||||
golang.org/x/time v0.11.0 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a // indirect
|
google.golang.org/grpc v1.78.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
google.golang.org/grpc v1.71.1 // indirect
|
|
||||||
google.golang.org/protobuf v1.36.6 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
186
go.sum
186
go.sum
@@ -1,97 +1,145 @@
|
|||||||
cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA=
|
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||||
cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q=
|
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||||
cloud.google.com/go/ai v0.10.1 h1:EU93KqYmMeOKgaBXAz2DshH2C/BzAT1P+iJORksLIic=
|
cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
|
||||||
cloud.google.com/go/ai v0.10.1/go.mod h1:sWWHZvmJ83BjuxAQtYEiA0SFTpijtbH+SXWFO14ri5A=
|
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
|
||||||
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps=
|
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||||
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
|
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
cloud.google.com/go/longrunning v0.6.6 h1:XJNDo5MUfMM05xK3ewpbSdmt7R2Zw+aQEMbdQR65Rbw=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
cloud.google.com/go/longrunning v0.6.6/go.mod h1:hyeGJUrPHcx0u2Uu1UFSoYZLn4lkMrccJig0t4FI7yw=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
||||||
|
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||||
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||||
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||||
|
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||||
|
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||||
|
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||||
|
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
|
|
||||||
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
|
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||||
|
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao=
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8=
|
||||||
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
|
github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y=
|
||||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14=
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9SkNrF8ErYlRoQ=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/openai/openai-go v0.1.0-beta.9 h1:ABpubc5yU/3ejee2GgRrbFta81SG/d7bQbB8mIdP0Xo=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
github.com/openai/openai-go v0.1.0-beta.9/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
|
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
|
||||||
|
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||||
|
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||||
|
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||||
|
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||||
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
|
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||||
|
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||||
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
|
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||||
|
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||||
|
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw=
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM=
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU=
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ=
|
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||||
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
|
||||||
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
|
||||||
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
|
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
|
||||||
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
|
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
|
||||||
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
|
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
|
||||||
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
|
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
|
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
|
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
|
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
|
||||||
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
|
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
|
||||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
|
||||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
|
||||||
golang.org/x/image v0.29.0 h1:HcdsyR4Gsuys/Axh0rDEmlBmB68rW1U9BUdB3UVHsas=
|
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||||
golang.org/x/image v0.29.0/go.mod h1:RVJROnf3SLK8d26OW91j4FrIHGbsJ8QnbEocVTOWQDA=
|
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
|
||||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||||
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
|
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
|
||||||
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
|
||||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
|
||||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
|
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||||
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
|
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs=
|
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac=
|
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI=
|
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||||
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI=
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
google.golang.org/genai v1.43.0 h1:8vhqhzJNZu1U94e2m+KvDq/TUUjSmDrs1aKkvTa8SoM=
|
||||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
google.golang.org/genai v1.43.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d h1:xXzuihhT3gL/ntduUZwHECzAn57E8dA6l8SOtYWdD8Q=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||||
|
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||||
|
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||||
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
103
google.go
103
google.go
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -8,26 +8,28 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/google/generative-ai-go/genai"
|
"google.golang.org/genai"
|
||||||
"google.golang.org/api/option"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type google struct {
|
type googleImpl struct {
|
||||||
key string
|
key string
|
||||||
model string
|
model string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
var _ LLM = googleImpl{}
|
||||||
|
|
||||||
|
func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||||
g.model = modelVersion
|
g.model = modelVersion
|
||||||
|
|
||||||
return g, nil
|
return g, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
|
func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) {
|
||||||
res := *model
|
var contents []*genai.Content
|
||||||
|
var cfg genai.GenerateContentConfig
|
||||||
|
|
||||||
for _, tool := range in.Toolbox.functions {
|
for _, tool := range in.Toolbox.Functions() {
|
||||||
res.Tools = append(res.Tools, &genai.Tool{
|
cfg.Tools = append(cfg.Tools, &genai.Tool{
|
||||||
FunctionDeclarations: []*genai.FunctionDeclaration{
|
FunctionDeclarations: []*genai.FunctionDeclaration{
|
||||||
{
|
{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
@@ -38,48 +40,44 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if !in.Toolbox.RequiresTool() {
|
if in.Toolbox.RequiresTool() {
|
||||||
res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
|
cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
|
||||||
Mode: genai.FunctionCallingAny,
|
Mode: genai.FunctionCallingConfigModeAny,
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
cs := res.StartChat()
|
for _, c := range in.Messages {
|
||||||
|
var role genai.Role
|
||||||
for i, c := range in.Messages {
|
|
||||||
content := genai.NewUserContent(genai.Text(c.Text))
|
|
||||||
|
|
||||||
switch c.Role {
|
switch c.Role {
|
||||||
case RoleAssistant, RoleSystem:
|
case RoleAssistant, RoleSystem:
|
||||||
content.Role = "model"
|
role = genai.RoleModel
|
||||||
|
|
||||||
case RoleUser:
|
case RoleUser:
|
||||||
content.Role = "user"
|
role = genai.RoleUser
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []*genai.Part
|
||||||
|
if c.Text != "" {
|
||||||
|
parts = append(parts, genai.NewPartFromText(c.Text))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, img := range c.Images {
|
for _, img := range c.Images {
|
||||||
if img.Url != "" {
|
if img.Url != "" {
|
||||||
// gemini does not support URLs, so we need to download the image and convert it to a blob
|
// gemini does not support URLs, so we need to download the image and convert it to a blob
|
||||||
|
|
||||||
// Download the image from the URL
|
|
||||||
resp, err := http.Get(img.Url)
|
resp, err := http.Get(img.Url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("error downloading image: %v", err))
|
panic(fmt.Sprintf("error downloading image: %v", err))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// Check the Content-Length to ensure it's not over 20MB
|
|
||||||
if resp.ContentLength > 20*1024*1024 {
|
if resp.ContentLength > 20*1024*1024 {
|
||||||
panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
|
panic(fmt.Sprintf("image size exceeds 20MB: %d bytes", resp.ContentLength))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the content into a byte slice
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("error reading image data: %v", err))
|
panic(fmt.Sprintf("error reading image data: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the MIME type is appropriate
|
|
||||||
mimeType := http.DetectContentType(data)
|
mimeType := http.DetectContentType(data)
|
||||||
switch mimeType {
|
switch mimeType {
|
||||||
case "image/jpeg", "image/png", "image/gif":
|
case "image/jpeg", "image/png", "image/gif":
|
||||||
@@ -88,38 +86,24 @@ func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (
|
|||||||
panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
|
panic(fmt.Sprintf("unsupported image MIME type: %s", mimeType))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a genai.Blob using the validated image data
|
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
|
||||||
content.Parts = append(content.Parts, genai.Blob{
|
|
||||||
MIMEType: mimeType,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// convert base64 to blob
|
|
||||||
b, e := base64.StdEncoding.DecodeString(img.Base64)
|
b, e := base64.StdEncoding.DecodeString(img.Base64)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
panic(fmt.Sprintf("error decoding base64: %v", e))
|
panic(fmt.Sprintf("error decoding base64: %v", e))
|
||||||
}
|
}
|
||||||
|
|
||||||
content.Parts = append(content.Parts, genai.Blob{
|
parts = append(parts, genai.NewPartFromBytes(b, img.ContentType))
|
||||||
MIMEType: img.ContentType,
|
|
||||||
Data: b,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if this is the last message, we want to add to history, we want it to be the parts
|
contents = append(contents, genai.NewContentFromParts(parts, role))
|
||||||
if i == len(in.Messages)-1 {
|
|
||||||
return &res, cs, content.Parts
|
|
||||||
}
|
|
||||||
|
|
||||||
cs.History = append(cs.History, content)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &res, cs, nil
|
return contents, &cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
|
||||||
res := Response{}
|
res := Response{}
|
||||||
|
|
||||||
for _, c := range in.Candidates {
|
for _, c := range in.Candidates {
|
||||||
@@ -127,15 +111,12 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
|
|||||||
var set = false
|
var set = false
|
||||||
if c.Content != nil {
|
if c.Content != nil {
|
||||||
for _, p := range c.Content.Parts {
|
for _, p := range c.Content.Parts {
|
||||||
switch p.(type) {
|
if p.Text != "" {
|
||||||
case genai.Text:
|
|
||||||
choice.Content = string(p.(genai.Text))
|
|
||||||
set = true
|
set = true
|
||||||
|
choice.Content = p.Text
|
||||||
case genai.FunctionCall:
|
} else if p.FunctionCall != nil {
|
||||||
v := p.(genai.FunctionCall)
|
v := p.FunctionCall
|
||||||
b, e := json.Marshal(v.Args)
|
b, e := json.Marshal(v.Args)
|
||||||
|
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
return Response{}, fmt.Errorf("error marshalling args: %w", e)
|
||||||
}
|
}
|
||||||
@@ -150,8 +131,6 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
|
|||||||
|
|
||||||
choice.Calls = append(choice.Calls, call)
|
choice.Calls = append(choice.Calls, call)
|
||||||
set = true
|
set = true
|
||||||
default:
|
|
||||||
return Response{}, fmt.Errorf("unknown part type: %T", p)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,23 +144,19 @@ func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Respon
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
|
||||||
cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
|
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||||
|
APIKey: g.key,
|
||||||
|
Backend: genai.BackendGeminiAPI,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
return Response{}, fmt.Errorf("error creating genai client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
model := cl.GenerativeModel(g.model)
|
contents, cfg := g.requestToContents(req)
|
||||||
|
|
||||||
_, cs, parts := g.requestToChatHistory(req, model)
|
|
||||||
|
|
||||||
resp, err := cs.SendMessage(ctx, parts...)
|
|
||||||
|
|
||||||
//parts := g.requestToGoogleRequest(req, model)
|
|
||||||
|
|
||||||
//resp, err := model.GenerateContent(ctx, parts...)
|
|
||||||
|
|
||||||
|
resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, fmt.Errorf("error generating content: %w", err)
|
return Response{}, fmt.Errorf("error generating content: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package utils
|
package imageutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"golang.org/x/image/draw"
|
"golang.org/x/image/draw"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CompressImage takes a base‑64‑encoded image (JPEG, PNG or GIF) and returns
|
// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns
|
||||||
// a base‑64‑encoded version that is at most maxLength in size, or an error.
|
// a base-64-encoded version that is at most maxLength in size, or an error.
|
||||||
func CompressImage(b64 string, maxLength int) (string, string, error) {
|
func CompressImage(b64 string, maxLength int) (string, string, error) {
|
||||||
raw, err := base64.StdEncoding.DecodeString(b64)
|
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -29,12 +29,12 @@ func CompressImage(b64 string, maxLength int) (string, string, error) {
|
|||||||
case "image/gif":
|
case "image/gif":
|
||||||
return compressGIF(raw, maxLength)
|
return compressGIF(raw, maxLength)
|
||||||
|
|
||||||
default: // jpeg, png, webp, etc. → treat as raster
|
default: // jpeg, png, webp, etc. -> treat as raster
|
||||||
return compressRaster(raw, maxLength)
|
return compressRaster(raw, maxLength)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Raster path (jpeg / png / single‑frame gif) ----------
|
// ---------- Raster path (jpeg / png / single-frame gif) ----------
|
||||||
|
|
||||||
func compressRaster(src []byte, maxLength int) (string, string, error) {
|
func compressRaster(src []byte, maxLength int) (string, string, error) {
|
||||||
img, _, err := image.Decode(bytes.NewReader(src))
|
img, _, err := image.Decode(bytes.NewReader(src))
|
||||||
@@ -57,7 +57,7 @@ func compressRaster(src []byte, maxLength int) (string, string, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// down‑scale 80%
|
// down-scale 80%
|
||||||
b := img.Bounds()
|
b := img.Bounds()
|
||||||
if b.Dx() < 100 || b.Dy() < 100 {
|
if b.Dx() < 100 || b.Dy() < 100 {
|
||||||
return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
|
return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
|
||||||
@@ -86,7 +86,7 @@ func compressGIF(src []byte, maxLength int) (string, string, error) {
|
|||||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// down‑scale every frame by 80%
|
// down-scale every frame by 80%
|
||||||
w, h := g.Config.Width, g.Config.Height
|
w, h := g.Config.Width, g.Config.Height
|
||||||
if w < 100 || h < 100 {
|
if w < 100 || h < 100 {
|
||||||
return "", "", fmt.Errorf("cannot compress animated GIF below 5 MiB without excessive quality loss")
|
return "", "", fmt.Errorf("cannot compress animated GIF below 5 MiB without excessive quality loss")
|
||||||
@@ -94,7 +94,7 @@ func compressGIF(src []byte, maxLength int) (string, string, error) {
|
|||||||
|
|
||||||
nw, nh := int(float64(w)*0.8), int(float64(h)*0.8)
|
nw, nh := int(float64(w)*0.8), int(float64(h)*0.8)
|
||||||
for i, frm := range g.Image {
|
for i, frm := range g.Image {
|
||||||
// convert paletted frame → RGBA for scaling
|
// convert paletted frame -> RGBA for scaling
|
||||||
rgba := image.NewRGBA(frm.Bounds())
|
rgba := image.NewRGBA(frm.Bounds())
|
||||||
draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src)
|
draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src)
|
||||||
|
|
||||||
@@ -109,6 +109,6 @@ func compressGIF(src []byte, maxLength int) (string, string, error) {
|
|||||||
g.Image[i] = paletted
|
g.Image[i] = paletted
|
||||||
}
|
}
|
||||||
g.Config.Width, g.Config.Height = nw, nh
|
g.Config.Width, g.Config.Height = nw, nh
|
||||||
// loop back and test size again …
|
// loop back and test size again ...
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
272
llm.go
272
llm.go
@@ -1,286 +1,30 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/openai/openai-go"
|
|
||||||
"github.com/openai/openai-go/packages/param"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Role string
|
// ChatCompletion is the interface for chat completion.
|
||||||
|
|
||||||
const (
|
|
||||||
RoleSystem Role = "system"
|
|
||||||
RoleUser Role = "user"
|
|
||||||
RoleAssistant Role = "assistant"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Image struct {
|
|
||||||
Base64 string
|
|
||||||
ContentType string
|
|
||||||
Url string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i Image) toRaw() map[string]any {
|
|
||||||
res := map[string]any{
|
|
||||||
"base64": i.Base64,
|
|
||||||
"contenttype": i.ContentType,
|
|
||||||
"url": i.Url,
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Image) fromRaw(raw map[string]any) Image {
|
|
||||||
var res Image
|
|
||||||
|
|
||||||
res.Base64 = raw["base64"].(string)
|
|
||||||
res.ContentType = raw["contenttype"].(string)
|
|
||||||
res.Url = raw["url"].(string)
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Role Role
|
|
||||||
Name string
|
|
||||||
Text string
|
|
||||||
Images []Image
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Message) toRaw() map[string]any {
|
|
||||||
res := map[string]any{
|
|
||||||
"role": m.Role,
|
|
||||||
"name": m.Name,
|
|
||||||
"text": m.Text,
|
|
||||||
}
|
|
||||||
|
|
||||||
images := make([]map[string]any, 0, len(m.Images))
|
|
||||||
for _, img := range m.Images {
|
|
||||||
images = append(images, img.toRaw())
|
|
||||||
}
|
|
||||||
|
|
||||||
res["images"] = images
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) fromRaw(raw map[string]any) Message {
|
|
||||||
var res Message
|
|
||||||
|
|
||||||
res.Role = Role(raw["role"].(string))
|
|
||||||
res.Name = raw["name"].(string)
|
|
||||||
res.Text = raw["text"].(string)
|
|
||||||
|
|
||||||
images := raw["images"].([]map[string]any)
|
|
||||||
for _, img := range images {
|
|
||||||
var i Image
|
|
||||||
|
|
||||||
res.Images = append(res.Images, i.fromRaw(img))
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Message) toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion {
|
|
||||||
var res openai.ChatCompletionMessageParamUnion
|
|
||||||
|
|
||||||
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
|
|
||||||
var textContent param.Opt[string]
|
|
||||||
|
|
||||||
for _, img := range m.Images {
|
|
||||||
if img.Base64 != "" {
|
|
||||||
arrayOfContentParts = append(arrayOfContentParts,
|
|
||||||
openai.ChatCompletionContentPartUnionParam{
|
|
||||||
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
|
||||||
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
|
||||||
URL: "data:" + img.ContentType + ";base64," + img.Base64,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else if img.Url != "" {
|
|
||||||
arrayOfContentParts = append(arrayOfContentParts,
|
|
||||||
openai.ChatCompletionContentPartUnionParam{
|
|
||||||
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
|
||||||
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
|
||||||
URL: img.Url,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.Text != "" {
|
|
||||||
if len(arrayOfContentParts) > 0 {
|
|
||||||
arrayOfContentParts = append(arrayOfContentParts,
|
|
||||||
openai.ChatCompletionContentPartUnionParam{
|
|
||||||
OfText: &openai.ChatCompletionContentPartTextParam{
|
|
||||||
Text: "\n",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
textContent = openai.String(m.Text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
a := strings.Split(model, "-")
|
|
||||||
|
|
||||||
useSystemInsteadOfDeveloper := true
|
|
||||||
if len(a) > 1 && a[0][0] == 'o' {
|
|
||||||
useSystemInsteadOfDeveloper = false
|
|
||||||
}
|
|
||||||
|
|
||||||
switch m.Role {
|
|
||||||
case RoleSystem:
|
|
||||||
if useSystemInsteadOfDeveloper {
|
|
||||||
res = openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfSystem: &openai.ChatCompletionSystemMessageParam{
|
|
||||||
Content: openai.ChatCompletionSystemMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
res = openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
|
|
||||||
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case RoleUser:
|
|
||||||
var name param.Opt[string]
|
|
||||||
if m.Name != "" {
|
|
||||||
name = openai.String(m.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
res = openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfUser: &openai.ChatCompletionUserMessageParam{
|
|
||||||
Name: name,
|
|
||||||
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
OfArrayOfContentParts: arrayOfContentParts,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
case RoleAssistant:
|
|
||||||
var name param.Opt[string]
|
|
||||||
if m.Name != "" {
|
|
||||||
name = openai.String(m.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
res = openai.ChatCompletionMessageParamUnion{
|
|
||||||
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
|
|
||||||
Name: name,
|
|
||||||
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
|
|
||||||
OfString: textContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return []openai.ChatCompletionMessageParamUnion{res}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string
|
|
||||||
FunctionCall FunctionCall
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t ToolCall) toRaw() map[string]any {
|
|
||||||
res := map[string]any{
|
|
||||||
"id": t.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
res["function"] = t.FunctionCall.toRaw()
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t ToolCall) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
|
|
||||||
return []openai.ChatCompletionMessageParamUnion{{
|
|
||||||
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
|
|
||||||
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
|
|
||||||
{
|
|
||||||
ID: t.ID,
|
|
||||||
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
|
||||||
Name: t.FunctionCall.Name,
|
|
||||||
Arguments: t.FunctionCall.Arguments,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCallResponse struct {
|
|
||||||
ID string
|
|
||||||
Result any
|
|
||||||
Error error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t ToolCallResponse) toRaw() map[string]any {
|
|
||||||
res := map[string]any{
|
|
||||||
"id": t.ID,
|
|
||||||
"result": t.Result,
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.Error != nil {
|
|
||||||
res["error"] = t.Error.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t ToolCallResponse) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
|
|
||||||
var refusal string
|
|
||||||
if t.Error != nil {
|
|
||||||
refusal = t.Error.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
if refusal != "" {
|
|
||||||
if t.Result != "" {
|
|
||||||
t.Result = fmt.Sprint(t.Result) + " (error in execution: " + refusal + ")"
|
|
||||||
} else {
|
|
||||||
t.Result = "error in execution:" + refusal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return []openai.ChatCompletionMessageParamUnion{{
|
|
||||||
OfTool: &openai.ChatCompletionToolMessageParam{
|
|
||||||
ToolCallID: t.ID,
|
|
||||||
Content: openai.ChatCompletionToolMessageParamContentUnion{
|
|
||||||
OfString: openai.String(fmt.Sprint(t.Result)),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletion interface {
|
type ChatCompletion interface {
|
||||||
ChatComplete(ctx context.Context, req Request) (Response, error)
|
ChatComplete(ctx context.Context, req Request) (Response, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LLM is the interface for language model providers.
|
||||||
type LLM interface {
|
type LLM interface {
|
||||||
ModelVersion(modelVersion string) (ChatCompletion, error)
|
ModelVersion(modelVersion string) (ChatCompletion, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAI creates a new OpenAI LLM provider with the given API key.
|
||||||
func OpenAI(key string) LLM {
|
func OpenAI(key string) LLM {
|
||||||
return openaiImpl{key: key}
|
return openaiImpl{key: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Anthropic creates a new Anthropic LLM provider with the given API key.
|
||||||
func Anthropic(key string) LLM {
|
func Anthropic(key string) LLM {
|
||||||
return anthropic{key: key}
|
return anthropicImpl{key: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Google creates a new Google LLM provider with the given API key.
|
||||||
func Google(key string) LLM {
|
func Google(key string) LLM {
|
||||||
return google{key: key}
|
return googleImpl{key: key}
|
||||||
}
|
}
|
||||||
|
|||||||
238
mcp.go
Normal file
238
mcp.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MCPServer represents a connection to an MCP server.
|
||||||
|
// It manages the lifecycle of the connection and provides access to the server's tools.
|
||||||
|
type MCPServer struct {
|
||||||
|
// Name is a friendly name for this server (used for logging/identification)
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Command is the command to run the MCP server (for stdio transport)
|
||||||
|
Command string
|
||||||
|
|
||||||
|
// Args are arguments to pass to the command
|
||||||
|
Args []string
|
||||||
|
|
||||||
|
// Env are environment variables to set for the command (in addition to current environment)
|
||||||
|
Env []string
|
||||||
|
|
||||||
|
// URL is the URL for SSE or HTTP transport (alternative to Command)
|
||||||
|
URL string
|
||||||
|
|
||||||
|
// Transport specifies the transport type: "stdio" (default), "sse", or "http"
|
||||||
|
Transport string
|
||||||
|
|
||||||
|
client *mcp.Client
|
||||||
|
session *mcp.ClientSession
|
||||||
|
tools map[string]*mcp.Tool // tool name -> tool definition
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a connection to the MCP server.
|
||||||
|
func (m *MCPServer) Connect(ctx context.Context) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.session != nil {
|
||||||
|
return nil // Already connected
|
||||||
|
}
|
||||||
|
|
||||||
|
m.client = mcp.NewClient(&mcp.Implementation{
|
||||||
|
Name: "go-llm",
|
||||||
|
Version: "1.0.0",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var transport mcp.Transport
|
||||||
|
|
||||||
|
switch m.Transport {
|
||||||
|
case "sse":
|
||||||
|
transport = &mcp.SSEClientTransport{
|
||||||
|
Endpoint: m.URL,
|
||||||
|
}
|
||||||
|
case "http":
|
||||||
|
transport = &mcp.StreamableClientTransport{
|
||||||
|
Endpoint: m.URL,
|
||||||
|
}
|
||||||
|
default: // "stdio" or empty
|
||||||
|
cmd := exec.Command(m.Command, m.Args...)
|
||||||
|
cmd.Env = append(os.Environ(), m.Env...)
|
||||||
|
transport = &mcp.CommandTransport{
|
||||||
|
Command: cmd,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := m.client.Connect(ctx, transport, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to MCP server %s: %w", m.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.session = session
|
||||||
|
|
||||||
|
// Load tools
|
||||||
|
m.tools = make(map[string]*mcp.Tool)
|
||||||
|
for tool, err := range session.Tools(ctx, nil) {
|
||||||
|
if err != nil {
|
||||||
|
m.session.Close()
|
||||||
|
m.session = nil
|
||||||
|
return fmt.Errorf("failed to list tools from %s: %w", m.Name, err)
|
||||||
|
}
|
||||||
|
m.tools[tool.Name] = tool
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection to the MCP server.
|
||||||
|
func (m *MCPServer) Close() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.session == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.session.Close()
|
||||||
|
m.session = nil
|
||||||
|
m.tools = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected returns true if the server is connected.
|
||||||
|
func (m *MCPServer) IsConnected() bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.session != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tools returns the list of tool names available from this server.
|
||||||
|
func (m *MCPServer) Tools() []string {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
for name := range m.tools {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasTool returns true if this server provides the named tool.
|
||||||
|
func (m *MCPServer) HasTool(name string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
_, ok := m.tools[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallTool calls a tool on the MCP server.
|
||||||
|
func (m *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (any, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
session := m.session
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
return nil, fmt.Errorf("not connected to MCP server %s", m.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := session.CallTool(ctx, &mcp.CallToolParams{
|
||||||
|
Name: name,
|
||||||
|
Arguments: arguments,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the result content
|
||||||
|
if len(result.Content) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's a single text content, return it as a string
|
||||||
|
if len(result.Content) == 1 {
|
||||||
|
if textContent, ok := result.Content[0].(*mcp.TextContent); ok {
|
||||||
|
return textContent.Text, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For multiple contents or non-text, serialize to string
|
||||||
|
return contentToString(result.Content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toFunction converts an MCP tool to a go-llm Function (for schema purposes only).
|
||||||
|
func (m *MCPServer) toFunction(tool *mcp.Tool) Function {
|
||||||
|
var inputSchema map[string]any
|
||||||
|
if tool.InputSchema != nil {
|
||||||
|
data, err := json.Marshal(tool.InputSchema)
|
||||||
|
if err == nil {
|
||||||
|
_ = json.Unmarshal(data, &inputSchema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if inputSchema == nil {
|
||||||
|
inputSchema = map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Function{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Parameters: schema.NewRaw(inputSchema),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// contentToString converts MCP content to a string representation.
|
||||||
|
func contentToString(content []mcp.Content) string {
|
||||||
|
var parts []string
|
||||||
|
for _, c := range content {
|
||||||
|
switch tc := c.(type) {
|
||||||
|
case *mcp.TextContent:
|
||||||
|
parts = append(parts, tc.Text)
|
||||||
|
default:
|
||||||
|
if data, err := json.Marshal(c); err == nil {
|
||||||
|
parts = append(parts, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) == 1 {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(parts)
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMCPServer adds an MCP server to the toolbox.
|
||||||
|
// The server must already be connected. Tools from the server will be available
|
||||||
|
// for use, and tool calls will be routed to the appropriate server.
|
||||||
|
func (t ToolBox) WithMCPServer(server *MCPServer) ToolBox {
|
||||||
|
if t.mcpServers == nil {
|
||||||
|
t.mcpServers = make(map[string]*MCPServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.mu.RLock()
|
||||||
|
defer server.mu.RUnlock()
|
||||||
|
|
||||||
|
for name, tool := range server.tools {
|
||||||
|
// Add the function definition (for schema)
|
||||||
|
fn := server.toFunction(tool)
|
||||||
|
t.functions[name] = fn
|
||||||
|
|
||||||
|
// Track which server owns this tool
|
||||||
|
t.mcpServers[name] = server
|
||||||
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
115
message.go
Normal file
115
message.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// Role represents the role of a message in a conversation.
|
||||||
|
type Role string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RoleSystem Role = "system"
|
||||||
|
RoleUser Role = "user"
|
||||||
|
RoleAssistant Role = "assistant"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Image represents an image that can be included in a message.
|
||||||
|
type Image struct {
|
||||||
|
Base64 string
|
||||||
|
ContentType string
|
||||||
|
Url string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i Image) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"base64": i.Base64,
|
||||||
|
"contenttype": i.ContentType,
|
||||||
|
"url": i.Url,
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Image) fromRaw(raw map[string]any) Image {
|
||||||
|
var res Image
|
||||||
|
|
||||||
|
res.Base64 = raw["base64"].(string)
|
||||||
|
res.ContentType = raw["contenttype"].(string)
|
||||||
|
res.Url = raw["url"].(string)
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message represents a message in a conversation.
|
||||||
|
type Message struct {
|
||||||
|
Role Role
|
||||||
|
Name string
|
||||||
|
Text string
|
||||||
|
Images []Image
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Message) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"role": m.Role,
|
||||||
|
"name": m.Name,
|
||||||
|
"text": m.Text,
|
||||||
|
}
|
||||||
|
|
||||||
|
images := make([]map[string]any, 0, len(m.Images))
|
||||||
|
for _, img := range m.Images {
|
||||||
|
images = append(images, img.toRaw())
|
||||||
|
}
|
||||||
|
|
||||||
|
res["images"] = images
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) fromRaw(raw map[string]any) Message {
|
||||||
|
var res Message
|
||||||
|
|
||||||
|
res.Role = Role(raw["role"].(string))
|
||||||
|
res.Name = raw["name"].(string)
|
||||||
|
res.Text = raw["text"].(string)
|
||||||
|
|
||||||
|
images := raw["images"].([]map[string]any)
|
||||||
|
for _, img := range images {
|
||||||
|
var i Image
|
||||||
|
|
||||||
|
res.Images = append(res.Images, i.fromRaw(img))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCall represents a tool call made by an assistant.
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string
|
||||||
|
FunctionCall FunctionCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t ToolCall) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"id": t.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
res["function"] = t.FunctionCall.toRaw()
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCallResponse represents the response to a tool call.
|
||||||
|
type ToolCallResponse struct {
|
||||||
|
ID string
|
||||||
|
Result any
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t ToolCallResponse) toRaw() map[string]any {
|
||||||
|
res := map[string]any{
|
||||||
|
"id": t.ID,
|
||||||
|
"result": t.Result,
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Error != nil {
|
||||||
|
res["error"] = t.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
210
openai.go
210
openai.go
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/openai/openai-go"
|
"github.com/openai/openai-go"
|
||||||
"github.com/openai/openai-go/option"
|
"github.com/openai/openai-go/option"
|
||||||
|
"github.com/openai/openai-go/packages/param"
|
||||||
"github.com/openai/openai-go/shared"
|
"github.com/openai/openai-go/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,14 +25,14 @@ func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatComple
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, i := range request.Conversation {
|
for _, i := range request.Conversation {
|
||||||
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
|
res.Messages = append(res.Messages, inputToChatCompletionMessages(i, o.model)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, msg := range request.Messages {
|
for _, msg := range request.Messages {
|
||||||
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
|
res.Messages = append(res.Messages, messageToChatCompletionMessages(msg, o.model)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tool := range request.Toolbox.functions {
|
for _, tool := range request.Toolbox.Functions() {
|
||||||
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
|
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: shared.FunctionDefinitionParam{
|
Function: shared.FunctionDefinitionParam{
|
||||||
@@ -111,10 +112,9 @@ func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response
|
|||||||
req := o.newRequestToOpenAIRequest(request)
|
req := o.newRequestToOpenAIRequest(request)
|
||||||
|
|
||||||
resp, err := cl.Chat.Completions.New(ctx, req)
|
resp, err := cl.Chat.Completions.New(ctx, req)
|
||||||
//resp, err := cl.CreateChatCompletion(ctx, req)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
return Response{}, fmt.Errorf("unhandled openai error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return o.responseToLLMResponse(resp), nil
|
return o.responseToLLMResponse(resp), nil
|
||||||
@@ -122,7 +122,201 @@ func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response
|
|||||||
|
|
||||||
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
||||||
return openaiImpl{
|
return openaiImpl{
|
||||||
key: o.key,
|
key: o.key,
|
||||||
model: modelVersion,
|
model: modelVersion,
|
||||||
|
baseUrl: o.baseUrl,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inputToChatCompletionMessages converts an Input to OpenAI chat completion messages.
|
||||||
|
func inputToChatCompletionMessages(input Input, model string) []openai.ChatCompletionMessageParamUnion {
|
||||||
|
switch v := input.(type) {
|
||||||
|
case Message:
|
||||||
|
return messageToChatCompletionMessages(v, model)
|
||||||
|
case ToolCall:
|
||||||
|
return toolCallToChatCompletionMessages(v)
|
||||||
|
case ToolCallResponse:
|
||||||
|
return toolCallResponseToChatCompletionMessages(v)
|
||||||
|
case ResponseChoice:
|
||||||
|
return responseChoiceToChatCompletionMessages(v)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageToChatCompletionMessages(m Message, model string) []openai.ChatCompletionMessageParamUnion {
|
||||||
|
var res openai.ChatCompletionMessageParamUnion
|
||||||
|
|
||||||
|
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
|
||||||
|
var textContent param.Opt[string]
|
||||||
|
|
||||||
|
for _, img := range m.Images {
|
||||||
|
if img.Base64 != "" {
|
||||||
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
|
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||||
|
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||||
|
URL: "data:" + img.ContentType + ";base64," + img.Base64,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else if img.Url != "" {
|
||||||
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
|
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||||
|
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||||
|
URL: img.Url,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Text != "" {
|
||||||
|
if len(arrayOfContentParts) > 0 {
|
||||||
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
|
OfText: &openai.ChatCompletionContentPartTextParam{
|
||||||
|
Text: "\n",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
textContent = openai.String(m.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a := strings.Split(model, "-")
|
||||||
|
|
||||||
|
useSystemInsteadOfDeveloper := true
|
||||||
|
if len(a) > 1 && a[0][0] == 'o' {
|
||||||
|
useSystemInsteadOfDeveloper = false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case RoleSystem:
|
||||||
|
if useSystemInsteadOfDeveloper {
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfSystem: &openai.ChatCompletionSystemMessageParam{
|
||||||
|
Content: openai.ChatCompletionSystemMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
|
||||||
|
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case RoleUser:
|
||||||
|
var name param.Opt[string]
|
||||||
|
if m.Name != "" {
|
||||||
|
name = openai.String(m.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfUser: &openai.ChatCompletionUserMessageParam{
|
||||||
|
Name: name,
|
||||||
|
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
OfArrayOfContentParts: arrayOfContentParts,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
case RoleAssistant:
|
||||||
|
var name param.Opt[string]
|
||||||
|
if m.Name != "" {
|
||||||
|
name = openai.String(m.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
|
||||||
|
Name: name,
|
||||||
|
Content: openai.ChatCompletionAssistantMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []openai.ChatCompletionMessageParamUnion{res}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallToChatCompletionMessages(t ToolCall) []openai.ChatCompletionMessageParamUnion {
|
||||||
|
return []openai.ChatCompletionMessageParamUnion{{
|
||||||
|
OfAssistant: &openai.ChatCompletionAssistantMessageParam{
|
||||||
|
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
|
||||||
|
{
|
||||||
|
ID: t.ID,
|
||||||
|
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||||
|
Name: t.FunctionCall.Name,
|
||||||
|
Arguments: t.FunctionCall.Arguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallResponseToChatCompletionMessages(t ToolCallResponse) []openai.ChatCompletionMessageParamUnion {
|
||||||
|
var refusal string
|
||||||
|
if t.Error != nil {
|
||||||
|
refusal = t.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
result := t.Result
|
||||||
|
if refusal != "" {
|
||||||
|
if result != "" {
|
||||||
|
result = fmt.Sprint(result) + " (error in execution: " + refusal + ")"
|
||||||
|
} else {
|
||||||
|
result = "error in execution:" + refusal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []openai.ChatCompletionMessageParamUnion{{
|
||||||
|
OfTool: &openai.ChatCompletionToolMessageParam{
|
||||||
|
ToolCallID: t.ID,
|
||||||
|
Content: openai.ChatCompletionToolMessageParamContentUnion{
|
||||||
|
OfString: openai.String(fmt.Sprint(result)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseChoiceToChatCompletionMessages(r ResponseChoice) []openai.ChatCompletionMessageParamUnion {
|
||||||
|
var as openai.ChatCompletionAssistantMessageParam
|
||||||
|
|
||||||
|
if r.Name != "" {
|
||||||
|
as.Name = openai.String(r.Name)
|
||||||
|
}
|
||||||
|
if r.Refusal != "" {
|
||||||
|
as.Refusal = openai.String(r.Refusal)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Content != "" {
|
||||||
|
as.Content.OfString = openai.String(r.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, call := range r.Calls {
|
||||||
|
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
|
||||||
|
ID: call.ID,
|
||||||
|
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||||
|
Name: call.FunctionCall.Name,
|
||||||
|
Arguments: call.FunctionCall.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return []openai.ChatCompletionMessageParamUnion{
|
||||||
|
{
|
||||||
|
OfAssistant: &as,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
219
openai_transcriber.go
Normal file
219
openai_transcriber.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
"github.com/openai/openai-go/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiTranscriber struct {
|
||||||
|
key string
|
||||||
|
model string
|
||||||
|
baseUrl string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Transcriber = openaiTranscriber{}
|
||||||
|
|
||||||
|
// OpenAITranscriber creates a transcriber backed by OpenAI's audio models.
|
||||||
|
// If model is empty, whisper-1 is used by default.
|
||||||
|
func OpenAITranscriber(key string, model string) Transcriber {
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
model = "whisper-1"
|
||||||
|
}
|
||||||
|
return openaiTranscriber{
|
||||||
|
key: key,
|
||||||
|
model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o openaiTranscriber) Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error) {
|
||||||
|
if len(wav) == 0 {
|
||||||
|
return Transcription{}, fmt.Errorf("wav data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
format := opts.ResponseFormat
|
||||||
|
if format == "" {
|
||||||
|
if strings.HasPrefix(o.model, "gpt-4o") {
|
||||||
|
format = TranscriptionResponseFormatJSON
|
||||||
|
} else {
|
||||||
|
format = TranscriptionResponseFormatVerboseJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if format != TranscriptionResponseFormatJSON && format != TranscriptionResponseFormatVerboseJSON {
|
||||||
|
return Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.TimestampGranularities) > 0 && format != TranscriptionResponseFormatVerboseJSON {
|
||||||
|
return Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := openai.AudioTranscriptionNewParams{
|
||||||
|
File: openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"),
|
||||||
|
Model: openai.AudioModel(o.model),
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Language != "" {
|
||||||
|
params.Language = openai.String(opts.Language)
|
||||||
|
}
|
||||||
|
if opts.Prompt != "" {
|
||||||
|
params.Prompt = openai.String(opts.Prompt)
|
||||||
|
}
|
||||||
|
if opts.Temperature != nil {
|
||||||
|
params.Temperature = openai.Float(*opts.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
params.ResponseFormat = openai.AudioResponseFormat(format)
|
||||||
|
|
||||||
|
if opts.IncludeLogprobs {
|
||||||
|
params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.TimestampGranularities) > 0 {
|
||||||
|
for _, granularity := range opts.TimestampGranularities {
|
||||||
|
params.TimestampGranularities = append(params.TimestampGranularities, string(granularity))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientOptions := []option.RequestOption{
|
||||||
|
option.WithAPIKey(o.key),
|
||||||
|
}
|
||||||
|
if o.baseUrl != "" {
|
||||||
|
clientOptions = append(clientOptions, option.WithBaseURL(o.baseUrl))
|
||||||
|
}
|
||||||
|
|
||||||
|
client := openai.NewClient(clientOptions...)
|
||||||
|
resp, err := client.Audio.Transcriptions.New(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return Transcription{}, fmt.Errorf("openai transcription failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return openaiTranscriptionToResult(o.model, resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type openaiVerboseTranscription struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Language string `json:"language"`
|
||||||
|
Duration float64 `json:"duration"`
|
||||||
|
Segments []openaiVerboseSegment `json:"segments"`
|
||||||
|
Words []openaiVerboseWord `json:"words"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openaiVerboseSegment struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Start float64 `json:"start"`
|
||||||
|
End float64 `json:"end"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
Tokens []int `json:"tokens"`
|
||||||
|
AvgLogprob *float64 `json:"avg_logprob"`
|
||||||
|
CompressionRatio *float64 `json:"compression_ratio"`
|
||||||
|
NoSpeechProb *float64 `json:"no_speech_prob"`
|
||||||
|
Words []openaiVerboseWord `json:"words"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openaiVerboseWord struct {
|
||||||
|
Word string `json:"word"`
|
||||||
|
Start float64 `json:"start"`
|
||||||
|
End float64 `json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func openaiTranscriptionToResult(model string, resp *openai.Transcription) Transcription {
|
||||||
|
result := Transcription{
|
||||||
|
Provider: "openai",
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Text = resp.Text
|
||||||
|
result.RawJSON = resp.RawJSON()
|
||||||
|
|
||||||
|
for _, logprob := range resp.Logprobs {
|
||||||
|
result.Logprobs = append(result.Logprobs, TranscriptionTokenLogprob{
|
||||||
|
Token: logprob.Token,
|
||||||
|
Bytes: logprob.Bytes,
|
||||||
|
Logprob: logprob.Logprob,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage := openaiUsageToTranscriptionUsage(resp.Usage); usage.Type != "" {
|
||||||
|
result.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RawJSON == "" {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
var verbose openaiVerboseTranscription
|
||||||
|
if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose.Text != "" {
|
||||||
|
result.Text = verbose.Text
|
||||||
|
}
|
||||||
|
result.Language = verbose.Language
|
||||||
|
result.DurationSeconds = verbose.Duration
|
||||||
|
|
||||||
|
for _, seg := range verbose.Segments {
|
||||||
|
segment := TranscriptionSegment{
|
||||||
|
ID: seg.ID,
|
||||||
|
Start: seg.Start,
|
||||||
|
End: seg.End,
|
||||||
|
Text: seg.Text,
|
||||||
|
Tokens: append([]int(nil), seg.Tokens...),
|
||||||
|
AvgLogprob: seg.AvgLogprob,
|
||||||
|
CompressionRatio: seg.CompressionRatio,
|
||||||
|
NoSpeechProb: seg.NoSpeechProb,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, word := range seg.Words {
|
||||||
|
segment.Words = append(segment.Words, TranscriptionWord{
|
||||||
|
Word: word.Word,
|
||||||
|
Start: word.Start,
|
||||||
|
End: word.End,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Segments = append(result.Segments, segment)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, word := range verbose.Words {
|
||||||
|
result.Words = append(result.Words, TranscriptionWord{
|
||||||
|
Word: word.Word,
|
||||||
|
Start: word.Start,
|
||||||
|
End: word.End,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func openaiUsageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) TranscriptionUsage {
|
||||||
|
switch usage.Type {
|
||||||
|
case "tokens":
|
||||||
|
tokens := usage.AsTokens()
|
||||||
|
return TranscriptionUsage{
|
||||||
|
Type: usage.Type,
|
||||||
|
InputTokens: tokens.InputTokens,
|
||||||
|
OutputTokens: tokens.OutputTokens,
|
||||||
|
TotalTokens: tokens.TotalTokens,
|
||||||
|
AudioTokens: tokens.InputTokenDetails.AudioTokens,
|
||||||
|
TextTokens: tokens.InputTokenDetails.TextTokens,
|
||||||
|
}
|
||||||
|
case "duration":
|
||||||
|
duration := usage.AsDuration()
|
||||||
|
return TranscriptionUsage{
|
||||||
|
Type: usage.Type,
|
||||||
|
Seconds: duration.Seconds,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return TranscriptionUsage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
2
parse.go
2
parse.go
@@ -1,4 +1,4 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|||||||
11
provider/anthropic/anthropic.go
Normal file
11
provider/anthropic/anthropic.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Package anthropic provides the Anthropic LLM provider.
|
||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New creates a new Anthropic LLM provider with the given API key.
|
||||||
|
func New(key string) llm.LLM {
|
||||||
|
return llm.Anthropic(key)
|
||||||
|
}
|
||||||
11
provider/google/google.go
Normal file
11
provider/google/google.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Package google provides the Google LLM provider.
|
||||||
|
package google
|
||||||
|
|
||||||
|
import (
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New creates a new Google LLM provider with the given API key.
|
||||||
|
func New(key string) llm.LLM {
|
||||||
|
return llm.Google(key)
|
||||||
|
}
|
||||||
11
provider/openai/openai.go
Normal file
11
provider/openai/openai.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Package openai provides the OpenAI LLM provider.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New creates a new OpenAI LLM provider with the given API key.
|
||||||
|
func New(key string) llm.LLM {
|
||||||
|
return llm.OpenAI(key)
|
||||||
|
}
|
||||||
25
request.go
25
request.go
@@ -1,17 +1,20 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/openai/openai-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
type rawAble interface {
|
|
||||||
toRaw() map[string]any
|
|
||||||
fromRaw(raw map[string]any) Input
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Input is the interface for conversation inputs.
|
||||||
|
// Types that implement this interface can be part of a conversation:
|
||||||
|
// Message, ToolCall, ToolCallResponse, and ResponseChoice.
|
||||||
type Input interface {
|
type Input interface {
|
||||||
toChatCompletionMessages(model string) []openai.ChatCompletionMessageParamUnion
|
// isInput is a marker method to ensure only valid types implement this interface.
|
||||||
|
isInput()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Implement Input interface for all valid input types.
|
||||||
|
func (Message) isInput() {}
|
||||||
|
func (ToolCall) isInput() {}
|
||||||
|
func (ToolCallResponse) isInput() {}
|
||||||
|
func (ResponseChoice) isInput() {}
|
||||||
|
|
||||||
|
// Request represents a request to a language model.
|
||||||
type Request struct {
|
type Request struct {
|
||||||
Conversation []Input
|
Conversation []Input
|
||||||
Messages []Message
|
Messages []Message
|
||||||
|
|||||||
38
response.go
38
response.go
@@ -1,9 +1,6 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/openai/openai-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
// ResponseChoice represents a single choice in a response.
|
||||||
type ResponseChoice struct {
|
type ResponseChoice struct {
|
||||||
Index int
|
Index int
|
||||||
Role Role
|
Role Role
|
||||||
@@ -32,36 +29,6 @@ func (r ResponseChoice) toRaw() map[string]any {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ResponseChoice) toChatCompletionMessages(_ string) []openai.ChatCompletionMessageParamUnion {
|
|
||||||
var as openai.ChatCompletionAssistantMessageParam
|
|
||||||
|
|
||||||
if r.Name != "" {
|
|
||||||
as.Name = openai.String(r.Name)
|
|
||||||
}
|
|
||||||
if r.Refusal != "" {
|
|
||||||
as.Refusal = openai.String(r.Refusal)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Content != "" {
|
|
||||||
as.Content.OfString = openai.String(r.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, call := range r.Calls {
|
|
||||||
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
|
|
||||||
ID: call.ID,
|
|
||||||
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
|
||||||
Name: call.FunctionCall.Name,
|
|
||||||
Arguments: call.FunctionCall.Arguments,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return []openai.ChatCompletionMessageParamUnion{
|
|
||||||
{
|
|
||||||
OfAssistant: &as,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r ResponseChoice) toInput() []Input {
|
func (r ResponseChoice) toInput() []Input {
|
||||||
var res []Input
|
var res []Input
|
||||||
|
|
||||||
@@ -79,6 +46,7 @@ func (r ResponseChoice) toInput() []Input {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Response represents a response from a language model.
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Choices []ResponseChoice
|
Choices []ResponseChoice
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/google/generative-ai-go/genai"
|
|
||||||
"github.com/openai/openai-go"
|
"github.com/openai/openai-go"
|
||||||
|
"google.golang.org/genai"
|
||||||
)
|
)
|
||||||
|
|
||||||
type array struct {
|
type array struct {
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/google/generative-ai-go/genai"
|
|
||||||
"github.com/openai/openai-go"
|
"github.com/openai/openai-go"
|
||||||
|
"google.golang.org/genai"
|
||||||
)
|
)
|
||||||
|
|
||||||
// just enforcing that basic implements Type
|
// just enforcing that basic implements Type
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/google/generative-ai-go/genai"
|
|
||||||
"github.com/openai/openai-go"
|
"github.com/openai/openai-go"
|
||||||
|
"google.golang.org/genai"
|
||||||
)
|
)
|
||||||
|
|
||||||
type enum struct {
|
type enum struct {
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/google/generative-ai-go/genai"
|
|
||||||
"github.com/openai/openai-go"
|
"github.com/openai/openai-go"
|
||||||
|
"google.golang.org/genai"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
134
schema/raw.go
Normal file
134
schema/raw.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
"google.golang.org/genai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Raw represents a raw JSON schema that is passed through directly.
|
||||||
|
// This is used for MCP tools where we receive the schema from the server.
|
||||||
|
type Raw struct {
|
||||||
|
schema map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRaw creates a new Raw schema from a map.
|
||||||
|
func NewRaw(schema map[string]any) Raw {
|
||||||
|
if schema == nil {
|
||||||
|
schema = map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Raw{schema: schema}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRawFromJSON creates a new Raw schema from JSON bytes.
|
||||||
|
func NewRawFromJSON(data []byte) (Raw, error) {
|
||||||
|
var schema map[string]any
|
||||||
|
if err := json.Unmarshal(data, &schema); err != nil {
|
||||||
|
return Raw{}, fmt.Errorf("failed to parse JSON schema: %w", err)
|
||||||
|
}
|
||||||
|
return NewRaw(schema), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) OpenAIParameters() openai.FunctionParameters {
|
||||||
|
return openai.FunctionParameters(r.schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) GoogleParameters() *genai.Schema {
|
||||||
|
return mapToGenaiSchema(r.schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) AnthropicInputSchema() map[string]any {
|
||||||
|
return r.schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) Required() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) Description() string {
|
||||||
|
if desc, ok := r.schema["description"].(string); ok {
|
||||||
|
return desc
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) FromAny(val any) (reflect.Value, error) {
|
||||||
|
return reflect.ValueOf(val), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Raw) SetValueOnField(obj reflect.Value, val reflect.Value) {
|
||||||
|
// No-op for raw schemas
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapToGenaiSchema converts a map[string]any JSON schema to genai.Schema
|
||||||
|
func mapToGenaiSchema(m map[string]any) *genai.Schema {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := &genai.Schema{}
|
||||||
|
|
||||||
|
// Type
|
||||||
|
if t, ok := m["type"].(string); ok {
|
||||||
|
switch t {
|
||||||
|
case "string":
|
||||||
|
schema.Type = genai.TypeString
|
||||||
|
case "number":
|
||||||
|
schema.Type = genai.TypeNumber
|
||||||
|
case "integer":
|
||||||
|
schema.Type = genai.TypeInteger
|
||||||
|
case "boolean":
|
||||||
|
schema.Type = genai.TypeBoolean
|
||||||
|
case "array":
|
||||||
|
schema.Type = genai.TypeArray
|
||||||
|
case "object":
|
||||||
|
schema.Type = genai.TypeObject
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description
|
||||||
|
if desc, ok := m["description"].(string); ok {
|
||||||
|
schema.Description = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enum
|
||||||
|
if enum, ok := m["enum"].([]any); ok {
|
||||||
|
for _, e := range enum {
|
||||||
|
if s, ok := e.(string); ok {
|
||||||
|
schema.Enum = append(schema.Enum, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Properties (for objects)
|
||||||
|
if props, ok := m["properties"].(map[string]any); ok {
|
||||||
|
schema.Properties = make(map[string]*genai.Schema)
|
||||||
|
for k, v := range props {
|
||||||
|
if vm, ok := v.(map[string]any); ok {
|
||||||
|
schema.Properties[k] = mapToGenaiSchema(vm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required
|
||||||
|
if req, ok := m["required"].([]any); ok {
|
||||||
|
for _, r := range req {
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
schema.Required = append(schema.Required, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Items (for arrays)
|
||||||
|
if items, ok := m["items"].(map[string]any); ok {
|
||||||
|
schema.Items = mapToGenaiSchema(items)
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
}
|
||||||
@@ -3,8 +3,8 @@ package schema
|
|||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/google/generative-ai-go/genai"
|
|
||||||
"github.com/openai/openai-go"
|
"github.com/openai/openai-go"
|
||||||
|
"google.golang.org/genai"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Type interface {
|
type Type interface {
|
||||||
|
|||||||
16
toolbox.go
16
toolbox.go
@@ -1,7 +1,8 @@
|
|||||||
package go_llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
// the correct parameters.
|
// the correct parameters.
|
||||||
type ToolBox struct {
|
type ToolBox struct {
|
||||||
functions map[string]Function
|
functions map[string]Function
|
||||||
|
mcpServers map[string]*MCPServer // tool name -> MCP server that provides it
|
||||||
dontRequireTool bool
|
dontRequireTool bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +93,18 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
func (t ToolBox) executeFunction(ctx *Context, functionName string, params string) (any, error) {
|
||||||
|
// Check if this is an MCP tool
|
||||||
|
if server, ok := t.mcpServers[functionName]; ok {
|
||||||
|
var args map[string]any
|
||||||
|
if params != "" {
|
||||||
|
if err := json.Unmarshal([]byte(params), &args); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse MCP tool arguments: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return server.CallTool(ctx, functionName, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular function
|
||||||
f, ok := t.functions[functionName]
|
f, ok := t.functions[functionName]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
145
transcriber.go
Normal file
145
transcriber.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transcriber abstracts a speech-to-text model implementation.
|
||||||
|
type Transcriber interface {
|
||||||
|
Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionResponseFormat controls the output format requested from a transcriber.
|
||||||
|
type TranscriptionResponseFormat string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TranscriptionResponseFormatJSON TranscriptionResponseFormat = "json"
|
||||||
|
TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json"
|
||||||
|
TranscriptionResponseFormatText TranscriptionResponseFormat = "text"
|
||||||
|
TranscriptionResponseFormatSRT TranscriptionResponseFormat = "srt"
|
||||||
|
TranscriptionResponseFormatVTT TranscriptionResponseFormat = "vtt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TranscriptionTimestampGranularity defines the requested timestamp detail.
|
||||||
|
type TranscriptionTimestampGranularity string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word"
|
||||||
|
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TranscriptionOptions configures transcription behavior.
|
||||||
|
type TranscriptionOptions struct {
|
||||||
|
Language string
|
||||||
|
Prompt string
|
||||||
|
Temperature *float64
|
||||||
|
ResponseFormat TranscriptionResponseFormat
|
||||||
|
TimestampGranularities []TranscriptionTimestampGranularity
|
||||||
|
IncludeLogprobs bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transcription captures a normalized transcription result.
|
||||||
|
type Transcription struct {
|
||||||
|
Provider string
|
||||||
|
Model string
|
||||||
|
Text string
|
||||||
|
Language string
|
||||||
|
DurationSeconds float64
|
||||||
|
Segments []TranscriptionSegment
|
||||||
|
Words []TranscriptionWord
|
||||||
|
Logprobs []TranscriptionTokenLogprob
|
||||||
|
Usage TranscriptionUsage
|
||||||
|
RawJSON string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionSegment provides a coarse time-sliced transcription segment.
|
||||||
|
type TranscriptionSegment struct {
|
||||||
|
ID int
|
||||||
|
Start float64
|
||||||
|
End float64
|
||||||
|
Text string
|
||||||
|
Tokens []int
|
||||||
|
AvgLogprob *float64
|
||||||
|
CompressionRatio *float64
|
||||||
|
NoSpeechProb *float64
|
||||||
|
Words []TranscriptionWord
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionWord provides a word-level timestamp.
|
||||||
|
type TranscriptionWord struct {
|
||||||
|
Word string
|
||||||
|
Start float64
|
||||||
|
End float64
|
||||||
|
Confidence *float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionTokenLogprob captures token-level log probability details.
|
||||||
|
type TranscriptionTokenLogprob struct {
|
||||||
|
Token string
|
||||||
|
Bytes []float64
|
||||||
|
Logprob float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionUsage captures token or duration usage details.
|
||||||
|
type TranscriptionUsage struct {
|
||||||
|
Type string
|
||||||
|
InputTokens int64
|
||||||
|
OutputTokens int64
|
||||||
|
TotalTokens int64
|
||||||
|
AudioTokens int64
|
||||||
|
TextTokens int64
|
||||||
|
Seconds float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscribeFile converts an audio file to WAV and transcribes it.
|
||||||
|
func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) {
|
||||||
|
if transcriber == nil {
|
||||||
|
return Transcription{}, fmt.Errorf("transcriber is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
wav, err := audioFileToWav(ctx, filename)
|
||||||
|
if err != nil {
|
||||||
|
return Transcription{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transcriber.Transcribe(ctx, wav, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func audioFileToWav(ctx context.Context, filename string) ([]byte, error) {
|
||||||
|
if filename == "" {
|
||||||
|
return nil, fmt.Errorf("filename is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(filepath.Ext(filename), ".wav") {
|
||||||
|
data, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read wav file: %w", err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create temp wav file: %w", err)
|
||||||
|
}
|
||||||
|
tempPath := tempFile.Name()
|
||||||
|
_ = tempFile.Close()
|
||||||
|
defer os.Remove(tempPath)
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output)))
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tempPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read converted wav file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
32
v2/CLAUDE.md
Normal file
32
v2/CLAUDE.md
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# CLAUDE.md for go-llm v2
|
||||||
|
|
||||||
|
## Build and Test Commands
|
||||||
|
- Build project: `cd v2 && go build ./...`
|
||||||
|
- Run all tests: `cd v2 && go test ./...`
|
||||||
|
- Run specific test: `cd v2 && go test -v -run <TestName> ./...`
|
||||||
|
- Tidy dependencies: `cd v2 && go mod tidy`
|
||||||
|
- Vet: `cd v2 && go vet ./...`
|
||||||
|
|
||||||
|
## Code Style Guidelines
|
||||||
|
- **Indentation**: Standard Go tabs
|
||||||
|
- **Naming**: `camelCase` for unexported, `PascalCase` for exported
|
||||||
|
- **Error Handling**: Always check and handle errors immediately. Wrap with `fmt.Errorf("%w: ...", err)`
|
||||||
|
- **Imports**: Standard library first, then third-party, then internal packages
|
||||||
|
|
||||||
|
## Package Structure
|
||||||
|
- Root package `llm` — public API (Client, Model, Chat, ToolBox, Message types)
|
||||||
|
- `provider/` — Provider interface that backends implement
|
||||||
|
- `openai/`, `anthropic/`, `google/` — Provider implementations
|
||||||
|
- `tools/` — Ready-to-use sample tools (WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP)
|
||||||
|
- `sandbox/` — Isolated Linux container environments via Proxmox LXC + SSH
|
||||||
|
- `internal/schema/` — JSON Schema generation from Go structs
|
||||||
|
- `internal/imageutil/` — Image compression utilities
|
||||||
|
|
||||||
|
## Key Design Decisions
|
||||||
|
1. Unified `Message` type instead of marker interfaces
|
||||||
|
2. `map[string]any` JSON Schema (no provider coupling)
|
||||||
|
3. Tool functions return `(string, error)`, use standard `context.Context`
|
||||||
|
4. `Chat.Send()` auto-loops tool calls; `Chat.SendRaw()` for manual control
|
||||||
|
5. MCP one-call connect: `MCPStdioServer(ctx, cmd, args...)`
|
||||||
|
6. Streaming via pull-based `StreamReader.Next()`
|
||||||
|
7. Middleware for logging, retry, timeout, usage tracking
|
||||||
113
v2/agent/agent.go
Normal file
113
v2/agent/agent.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// Package agent provides a simple agent abstraction built on top of go-llm.
|
||||||
|
//
|
||||||
|
// An Agent wraps a model, system prompt, and tools into a reusable unit.
|
||||||
|
// Agents can be turned into tools via AsTool, enabling parent agents to
|
||||||
|
// delegate work to specialized sub-agents through the normal tool-call loop.
|
||||||
|
//
|
||||||
|
// Example — orchestrator with sub-agents:
|
||||||
|
//
|
||||||
|
// researcher := agent.New(model, "You research topics via web search.",
|
||||||
|
// agent.WithTools(llm.NewToolBox(tools.WebSearch(apiKey))),
|
||||||
|
// )
|
||||||
|
// coder := agent.New(model, "You write and run code.",
|
||||||
|
// agent.WithTools(llm.NewToolBox(tools.Exec())),
|
||||||
|
// )
|
||||||
|
// orchestrator := agent.New(model, "You coordinate research and coding tasks.",
|
||||||
|
// agent.WithTools(llm.NewToolBox(
|
||||||
|
// researcher.AsTool("research", "Research a topic"),
|
||||||
|
// coder.AsTool("code", "Write and run code"),
|
||||||
|
// )),
|
||||||
|
// )
|
||||||
|
// result, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Agent is a configured LLM agent with a system prompt and tools.
|
||||||
|
// Each call to Run creates a fresh conversation (no state is carried between runs).
|
||||||
|
type Agent struct {
|
||||||
|
model *llm.Model
|
||||||
|
system string
|
||||||
|
tools *llm.ToolBox
|
||||||
|
reqOpts []llm.RequestOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures an Agent.
|
||||||
|
type Option func(*Agent)
|
||||||
|
|
||||||
|
// WithTools sets the tools available to the agent.
|
||||||
|
func WithTools(tb *llm.ToolBox) Option {
|
||||||
|
return func(a *Agent) { a.tools = tb }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestOptions sets default request options (temperature, max tokens, etc.)
|
||||||
|
// applied to every completion call the agent makes.
|
||||||
|
func WithRequestOptions(opts ...llm.RequestOption) Option {
|
||||||
|
return func(a *Agent) { a.reqOpts = opts }
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates an agent with the given model and system prompt.
|
||||||
|
func New(model *llm.Model, system string, opts ...Option) *Agent {
|
||||||
|
a := &Agent{
|
||||||
|
model: model,
|
||||||
|
system: system,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(a)
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run executes the agent with a user prompt. Each call is a fresh conversation.
|
||||||
|
// The agent loops tool calls automatically until it produces a text response.
|
||||||
|
func (a *Agent) Run(ctx context.Context, prompt string) (string, error) {
|
||||||
|
return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunMessages executes the agent with full message control.
|
||||||
|
// Each call is a fresh conversation. The agent loops tool calls automatically.
|
||||||
|
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, error) {
|
||||||
|
chat := llm.NewChat(a.model, a.reqOpts...)
|
||||||
|
if a.system != "" {
|
||||||
|
chat.SetSystem(a.system)
|
||||||
|
}
|
||||||
|
if a.tools != nil {
|
||||||
|
chat.SetTools(a.tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send each message; the last one triggers the completion loop.
|
||||||
|
// All but the last are added as context.
|
||||||
|
for i, msg := range messages {
|
||||||
|
if i < len(messages)-1 {
|
||||||
|
chat.AddToolResults(msg) // AddToolResults just appends to history
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return chat.SendMessage(ctx, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty messages — send an empty user message
|
||||||
|
return chat.Send(ctx, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// delegateParams is the parameter struct for the tool created by AsTool.
|
||||||
|
type delegateParams struct {
|
||||||
|
Input string `json:"input" description:"The task or question to delegate to this agent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AsTool creates a llm.Tool that delegates to this agent.
|
||||||
|
// When a parent agent calls this tool, it runs the agent with the provided input
|
||||||
|
// as the prompt and returns the agent's text response as the tool result.
|
||||||
|
//
|
||||||
|
// This enables sub-agent patterns where a parent agent can spawn specialized
|
||||||
|
// child agents through the normal tool-call mechanism.
|
||||||
|
func (a *Agent) AsTool(name, description string) llm.Tool {
|
||||||
|
return llm.Define[delegateParams](name, description,
|
||||||
|
func(ctx context.Context, p delegateParams) (string, error) {
|
||||||
|
return a.Run(ctx, p.Input)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
244
v2/agent/agent_test.go
Normal file
244
v2/agent/agent_test.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockProvider is a test helper that implements provider.Provider.
|
||||||
|
type mockProvider struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
completeFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
|
||||||
|
requests []provider.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.requests = append(m.requests, req)
|
||||||
|
m.mu.Unlock()
|
||||||
|
return m.completeFunc(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||||
|
close(events)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) lastRequest() provider.Request {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if len(m.requests) == 0 {
|
||||||
|
return provider.Request{}
|
||||||
|
}
|
||||||
|
return m.requests[len(m.requests)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model {
|
||||||
|
mp := &mockProvider{completeFunc: fn}
|
||||||
|
return llm.NewClient(mp).Model("mock-model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleMockModel(text string) *llm.Model {
|
||||||
|
return newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{Text: text}, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_Run(t *testing.T) {
|
||||||
|
model := newSimpleMockModel("Hello from agent!")
|
||||||
|
a := New(model, "You are a helpful assistant.")
|
||||||
|
|
||||||
|
result, err := a.Run(context.Background(), "Say hello")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "Hello from agent!" {
|
||||||
|
t.Errorf("expected 'Hello from agent!', got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_Run_WithTools(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
callCount++
|
||||||
|
if callCount == 1 {
|
||||||
|
// First call: model requests a tool call
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "greet", Arguments: `{}`},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// Second call: model returns text after seeing tool result
|
||||||
|
return provider.Response{Text: "Tool said: hello!"}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
tool := llm.DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||||
|
return "hello!", nil
|
||||||
|
})
|
||||||
|
|
||||||
|
a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool)))
|
||||||
|
result, err := a.Run(context.Background(), "Use the greet tool")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "Tool said: hello!" {
|
||||||
|
t.Errorf("expected 'Tool said: hello!', got %q", result)
|
||||||
|
}
|
||||||
|
if callCount != 2 {
|
||||||
|
t.Errorf("expected 2 calls (tool loop), got %d", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_AsTool(t *testing.T) {
|
||||||
|
// Create a child agent
|
||||||
|
childModel := newSimpleMockModel("child result: 42")
|
||||||
|
child := New(childModel, "You compute things.")
|
||||||
|
|
||||||
|
// Create the tool from the child agent
|
||||||
|
childTool := child.AsTool("compute", "Delegate computation to child agent")
|
||||||
|
|
||||||
|
// Verify tool metadata
|
||||||
|
if childTool.Name != "compute" {
|
||||||
|
t.Errorf("expected tool name 'compute', got %q", childTool.Name)
|
||||||
|
}
|
||||||
|
if childTool.Description != "Delegate computation to child agent" {
|
||||||
|
t.Errorf("expected correct description, got %q", childTool.Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the tool directly (simulating what the parent's Chat.Send loop does)
|
||||||
|
result, err := childTool.Execute(context.Background(), `{"input":"what is 6*7?"}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "child result: 42" {
|
||||||
|
t.Errorf("expected 'child result: 42', got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_AsTool_ParentChild(t *testing.T) {
|
||||||
|
// Child agent that always returns a fixed result
|
||||||
|
childModel := newSimpleMockModel("researched: Go generics are great")
|
||||||
|
child := New(childModel, "You are a researcher.")
|
||||||
|
|
||||||
|
// Parent agent: first call returns tool call, second returns text
|
||||||
|
parentCallCount := 0
|
||||||
|
parentModel := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
parentCallCount++
|
||||||
|
if parentCallCount == 1 {
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "research", Arguments: `{"input":"Tell me about Go generics"}`},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// After getting tool result, parent synthesizes final answer
|
||||||
|
return provider.Response{Text: "Based on research: Go generics are great"}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
parent := New(parentModel, "You coordinate tasks.",
|
||||||
|
WithTools(llm.NewToolBox(
|
||||||
|
child.AsTool("research", "Research a topic"),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := parent.Run(context.Background(), "Tell me about Go generics")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "Based on research: Go generics are great" {
|
||||||
|
t.Errorf("expected synthesized result, got %q", result)
|
||||||
|
}
|
||||||
|
if parentCallCount != 2 {
|
||||||
|
t.Errorf("expected 2 parent calls (tool loop), got %d", parentCallCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_RunMessages(t *testing.T) {
|
||||||
|
model := newSimpleMockModel("I see the system and user messages")
|
||||||
|
a := New(model, "You are helpful.")
|
||||||
|
|
||||||
|
messages := []llm.Message{
|
||||||
|
llm.UserMessage("First question"),
|
||||||
|
llm.AssistantMessage("First answer"),
|
||||||
|
llm.UserMessage("Follow up"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := a.RunMessages(context.Background(), messages)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "I see the system and user messages" {
|
||||||
|
t.Errorf("unexpected result: %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_ContextCancellation(t *testing.T) {
|
||||||
|
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, ctx.Err()
|
||||||
|
})
|
||||||
|
a := New(model, "You are helpful.")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
_, err := a.Run(ctx, "This should fail")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from cancelled context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_WithRequestOptions(t *testing.T) {
|
||||||
|
var capturedReq provider.Request
|
||||||
|
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
capturedReq = req
|
||||||
|
return provider.Response{Text: "ok"}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
a := New(model, "You are helpful.",
|
||||||
|
WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := a.Run(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedReq.Temperature == nil || *capturedReq.Temperature != 0.3 {
|
||||||
|
t.Errorf("expected temperature 0.3, got %v", capturedReq.Temperature)
|
||||||
|
}
|
||||||
|
if capturedReq.MaxTokens == nil || *capturedReq.MaxTokens != 100 {
|
||||||
|
t.Errorf("expected maxTokens 100, got %v", capturedReq.MaxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_Run_Error(t *testing.T) {
|
||||||
|
wantErr := errors.New("model failed")
|
||||||
|
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, wantErr
|
||||||
|
})
|
||||||
|
a := New(model, "You are helpful.")
|
||||||
|
|
||||||
|
_, err := a.Run(context.Background(), "test")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgent_EmptySystem(t *testing.T) {
|
||||||
|
model := newSimpleMockModel("no system prompt")
|
||||||
|
a := New(model, "") // Empty system prompt
|
||||||
|
|
||||||
|
result, err := a.Run(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "no system prompt" {
|
||||||
|
t.Errorf("unexpected result: %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
107
v2/agent/example_test.go
Normal file
107
v2/agent/example_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package agent_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/agent"
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A researcher agent that can search the web and browse pages.
|
||||||
|
func Example_researcher() {
|
||||||
|
model := llm.OpenAI(os.Getenv("OPENAI_API_KEY")).Model("gpt-4o")
|
||||||
|
|
||||||
|
researcher := agent.New(model,
|
||||||
|
"You are a research assistant. Use web search to find information, "+
|
||||||
|
"then use the browser to read full articles when needed. "+
|
||||||
|
"Provide a concise summary of your findings.",
|
||||||
|
agent.WithTools(llm.NewToolBox(
|
||||||
|
tools.WebSearch(os.Getenv("BRAVE_API_KEY")),
|
||||||
|
tools.Browser(),
|
||||||
|
)),
|
||||||
|
agent.WithRequestOptions(llm.WithTemperature(0.3)),
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Println(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A coder agent that can read, write, and execute code.
|
||||||
|
func Example_coder() {
|
||||||
|
model := llm.OpenAI(os.Getenv("OPENAI_API_KEY")).Model("gpt-4o")
|
||||||
|
|
||||||
|
coder := agent.New(model,
|
||||||
|
"You are a coding assistant. You can read files, write files, and execute commands. "+
|
||||||
|
"When asked to create a program, write the code to a file and then run it to verify it works.",
|
||||||
|
agent.WithTools(llm.NewToolBox(
|
||||||
|
tools.ReadFile(),
|
||||||
|
tools.WriteFile(),
|
||||||
|
tools.Exec(
|
||||||
|
tools.WithAllowedCommands([]string{"go", "python", "node", "cat", "ls"}),
|
||||||
|
tools.WithWorkDir(os.TempDir()),
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := coder.Run(context.Background(),
|
||||||
|
"Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Println(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// An orchestrator agent that delegates to specialized sub-agents.
|
||||||
|
// The orchestrator breaks a complex task into subtasks and dispatches them
|
||||||
|
// to the appropriate sub-agent via tool calls.
|
||||||
|
func Example_orchestrator() {
|
||||||
|
model := llm.OpenAI(os.Getenv("OPENAI_API_KEY")).Model("gpt-4o")
|
||||||
|
|
||||||
|
// Specialized sub-agents
|
||||||
|
researcher := agent.New(model,
|
||||||
|
"You are a research assistant. Search the web for information on the given topic "+
|
||||||
|
"and return a concise summary.",
|
||||||
|
agent.WithTools(llm.NewToolBox(
|
||||||
|
tools.WebSearch(os.Getenv("BRAVE_API_KEY")),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
|
coder := agent.New(model,
|
||||||
|
"You are a coding assistant. Write and test code as requested. "+
|
||||||
|
"Save files and run them to verify they work.",
|
||||||
|
agent.WithTools(llm.NewToolBox(
|
||||||
|
tools.ReadFile(),
|
||||||
|
tools.WriteFile(),
|
||||||
|
tools.Exec(tools.WithAllowedCommands([]string{"go", "python"})),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Orchestrator can delegate to both sub-agents
|
||||||
|
orchestrator := agent.New(model,
|
||||||
|
"You are a project manager. Break complex tasks into research and coding subtasks. "+
|
||||||
|
"Use delegate_research for information gathering and delegate_coding for implementation. "+
|
||||||
|
"Synthesize the results into a final answer.",
|
||||||
|
agent.WithTools(llm.NewToolBox(
|
||||||
|
researcher.AsTool("delegate_research",
|
||||||
|
"Delegate a research task. Provide a clear question or topic to research."),
|
||||||
|
coder.AsTool("delegate_coding",
|
||||||
|
"Delegate a coding task. Provide clear requirements for what to implement."),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := orchestrator.Run(context.Background(),
|
||||||
|
"Research how to implement a binary search tree in Go, then create one with insert and search operations.")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Println(result)
|
||||||
|
}
|
||||||
275
v2/anthropic/anthropic.go
Normal file
275
v2/anthropic/anthropic.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
// Package anthropic implements the go-llm v2 provider interface for Anthropic.
|
||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/imageutil"
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
|
||||||
|
anth "github.com/liushuangls/go-anthropic/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider implements the provider.Provider interface for Anthropic.
|
||||||
|
type Provider struct {
|
||||||
|
apiKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Anthropic provider.
|
||||||
|
func New(apiKey string) *Provider {
|
||||||
|
return &Provider{apiKey: apiKey}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete performs a non-streaming completion.
|
||||||
|
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
cl := anth.NewClient(p.apiKey)
|
||||||
|
|
||||||
|
anthReq := p.buildRequest(req)
|
||||||
|
|
||||||
|
resp, err := cl.CreateMessages(ctx, anthReq)
|
||||||
|
if err != nil {
|
||||||
|
return provider.Response{}, fmt.Errorf("anthropic completion error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.convertResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream performs a streaming completion.
|
||||||
|
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||||
|
cl := anth.NewClient(p.apiKey)
|
||||||
|
|
||||||
|
anthReq := p.buildRequest(req)
|
||||||
|
|
||||||
|
resp, err := cl.CreateMessagesStream(ctx, anth.MessagesStreamRequest{
|
||||||
|
MessagesRequest: anthReq,
|
||||||
|
OnContentBlockDelta: func(data anth.MessagesEventContentBlockDeltaData) {
|
||||||
|
if data.Delta.Type == "text_delta" && data.Delta.Text != nil {
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventText,
|
||||||
|
Text: *data.Delta.Text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("anthropic stream error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := p.convertResponse(resp)
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventDone,
|
||||||
|
Response: &result,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) buildRequest(req provider.Request) anth.MessagesRequest {
|
||||||
|
anthReq := anth.MessagesRequest{
|
||||||
|
Model: anth.Model(req.Model),
|
||||||
|
MaxTokens: 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
anthReq.MaxTokens = *req.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgs []anth.Message
|
||||||
|
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
if msg.Role == "system" {
|
||||||
|
if len(anthReq.System) > 0 {
|
||||||
|
anthReq.System += "\n"
|
||||||
|
}
|
||||||
|
anthReq.System += msg.Content
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Role == "tool" {
|
||||||
|
// Tool results in Anthropic format - use the helper
|
||||||
|
toolUseID := msg.ToolCallID
|
||||||
|
content := msg.Content
|
||||||
|
isError := false
|
||||||
|
msgs = append(msgs, anth.Message{
|
||||||
|
Role: anth.RoleUser,
|
||||||
|
Content: []anth.MessageContent{
|
||||||
|
{
|
||||||
|
Type: anth.MessagesContentTypeToolResult,
|
||||||
|
MessageContentToolResult: &anth.MessageContentToolResult{
|
||||||
|
ToolUseID: &toolUseID,
|
||||||
|
Content: []anth.MessageContent{
|
||||||
|
{
|
||||||
|
Type: anth.MessagesContentTypeText,
|
||||||
|
Text: &content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IsError: &isError,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
role := anth.RoleUser
|
||||||
|
if msg.Role == "assistant" {
|
||||||
|
role = anth.RoleAssistant
|
||||||
|
}
|
||||||
|
|
||||||
|
m := anth.Message{
|
||||||
|
Role: role,
|
||||||
|
Content: []anth.MessageContent{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Content != "" {
|
||||||
|
m.Content = append(m.Content, anth.MessageContent{
|
||||||
|
Type: anth.MessagesContentTypeText,
|
||||||
|
Text: &msg.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls in assistant messages
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
var input json.RawMessage
|
||||||
|
if tc.Arguments != "" {
|
||||||
|
input = json.RawMessage(tc.Arguments)
|
||||||
|
} else {
|
||||||
|
input = json.RawMessage("{}")
|
||||||
|
}
|
||||||
|
m.Content = append(m.Content, anth.MessageContent{
|
||||||
|
Type: anth.MessagesContentTypeToolUse,
|
||||||
|
MessageContentToolUse: &anth.MessageContentToolUse{
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Input: input,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle images
|
||||||
|
for _, img := range msg.Images {
|
||||||
|
if role == anth.RoleAssistant {
|
||||||
|
role = anth.RoleUser
|
||||||
|
m.Role = anth.RoleUser
|
||||||
|
}
|
||||||
|
|
||||||
|
if img.Base64 != "" {
|
||||||
|
b64 := img.Base64
|
||||||
|
contentType := img.ContentType
|
||||||
|
|
||||||
|
// Compress if > 5MiB
|
||||||
|
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||||
|
if err == nil && len(raw) >= 5242880 {
|
||||||
|
compressed, mime, cerr := imageutil.CompressImage(b64, 5*1024*1024)
|
||||||
|
if cerr == nil {
|
||||||
|
b64 = compressed
|
||||||
|
contentType = mime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Content = append(m.Content, anth.NewImageMessageContent(
|
||||||
|
anth.NewMessageContentSource(
|
||||||
|
anth.MessagesContentSourceTypeBase64,
|
||||||
|
contentType,
|
||||||
|
b64,
|
||||||
|
)))
|
||||||
|
} else if img.URL != "" {
|
||||||
|
// Download and convert to base64 (Anthropic doesn't support URLs directly)
|
||||||
|
resp, err := http.Get(img.URL)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
b64 := base64.StdEncoding.EncodeToString(data)
|
||||||
|
|
||||||
|
m.Content = append(m.Content, anth.NewImageMessageContent(
|
||||||
|
anth.NewMessageContentSource(
|
||||||
|
anth.MessagesContentSourceTypeBase64,
|
||||||
|
contentType,
|
||||||
|
b64,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio is not supported by Anthropic — skip silently.
|
||||||
|
|
||||||
|
// Merge consecutive same-role messages (Anthropic requires alternating)
|
||||||
|
if len(msgs) > 0 && msgs[len(msgs)-1].Role == role {
|
||||||
|
msgs[len(msgs)-1].Content = append(msgs[len(msgs)-1].Content, m.Content...)
|
||||||
|
} else {
|
||||||
|
msgs = append(msgs, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range req.Tools {
|
||||||
|
anthReq.Tools = append(anthReq.Tools, anth.ToolDefinition{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
InputSchema: tool.Schema,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
anthReq.Messages = msgs
|
||||||
|
|
||||||
|
if req.Temperature != nil {
|
||||||
|
f := float32(*req.Temperature)
|
||||||
|
anthReq.Temperature = &f
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.TopP != nil {
|
||||||
|
f := float32(*req.TopP)
|
||||||
|
anthReq.TopP = &f
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.Stop) > 0 {
|
||||||
|
anthReq.StopSequences = req.Stop
|
||||||
|
}
|
||||||
|
|
||||||
|
return anthReq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response {
|
||||||
|
var res provider.Response
|
||||||
|
var textParts []string
|
||||||
|
|
||||||
|
for _, block := range resp.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case anth.MessagesContentTypeText:
|
||||||
|
if block.Text != nil {
|
||||||
|
textParts = append(textParts, *block.Text)
|
||||||
|
}
|
||||||
|
case anth.MessagesContentTypeToolUse:
|
||||||
|
if block.MessageContentToolUse != nil {
|
||||||
|
args, _ := json.Marshal(block.MessageContentToolUse.Input)
|
||||||
|
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
|
||||||
|
ID: block.MessageContentToolUse.ID,
|
||||||
|
Name: block.MessageContentToolUse.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Text = strings.Join(textParts, "")
|
||||||
|
|
||||||
|
res.Usage = &provider.Usage{
|
||||||
|
InputTokens: resp.Usage.InputTokens,
|
||||||
|
OutputTokens: resp.Usage.OutputTokens,
|
||||||
|
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
153
v2/chat.go
Normal file
153
v2/chat.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Chat manages a multi-turn conversation with automatic history tracking
|
||||||
|
// and optional automatic tool-call execution.
|
||||||
|
type Chat struct {
|
||||||
|
model *Model
|
||||||
|
messages []Message
|
||||||
|
tools *ToolBox
|
||||||
|
opts []RequestOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChat creates a new conversation with the given model.
|
||||||
|
func NewChat(model *Model, opts ...RequestOption) *Chat {
|
||||||
|
return &Chat{
|
||||||
|
model: model,
|
||||||
|
opts: opts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSystem sets or replaces the system message.
|
||||||
|
func (c *Chat) SetSystem(text string) {
|
||||||
|
filtered := make([]Message, 0, len(c.messages)+1)
|
||||||
|
for _, m := range c.messages {
|
||||||
|
if m.Role != RoleSystem {
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.messages = append([]Message{SystemMessage(text)}, filtered...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTools configures the tools available for this chat.
|
||||||
|
func (c *Chat) SetTools(tb *ToolBox) {
|
||||||
|
c.tools = tb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a user message and returns the assistant's text response.
|
||||||
|
// If the model calls tools, they are executed automatically and the loop
|
||||||
|
// continues until the model produces a text response (the "agent loop").
|
||||||
|
func (c *Chat) Send(ctx context.Context, text string) (string, error) {
|
||||||
|
return c.SendMessage(ctx, UserMessage(text))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendWithImages sends a user message with images attached.
|
||||||
|
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) {
|
||||||
|
return c.SendMessage(ctx, UserMessageWithImages(text, images...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends an arbitrary message and returns the final text response.
|
||||||
|
// Handles the full tool-call loop automatically.
|
||||||
|
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) {
|
||||||
|
c.messages = append(c.messages, msg)
|
||||||
|
|
||||||
|
opts := c.buildOpts()
|
||||||
|
|
||||||
|
for {
|
||||||
|
resp, err := c.model.Complete(ctx, c.messages, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("completion failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.messages = append(c.messages, resp.Message())
|
||||||
|
|
||||||
|
if !resp.HasToolCalls() {
|
||||||
|
return resp.Text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.tools == nil {
|
||||||
|
return "", ErrNoToolsConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("tool execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.messages = append(c.messages, toolResults...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRaw sends a message and returns the raw Response without automatic tool execution.
|
||||||
|
// Useful when you want to handle tool calls manually.
|
||||||
|
func (c *Chat) SendRaw(ctx context.Context, msg Message) (Response, error) {
|
||||||
|
c.messages = append(c.messages, msg)
|
||||||
|
|
||||||
|
opts := c.buildOpts()
|
||||||
|
|
||||||
|
resp, err := c.model.Complete(ctx, c.messages, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.messages = append(c.messages, resp.Message())
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendStream sends a user message and returns a StreamReader for streaming responses.
|
||||||
|
func (c *Chat) SendStream(ctx context.Context, text string) (*StreamReader, error) {
|
||||||
|
c.messages = append(c.messages, UserMessage(text))
|
||||||
|
|
||||||
|
opts := c.buildOpts()
|
||||||
|
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := buildProviderRequest(c.model.model, c.messages, cfg)
|
||||||
|
return newStreamReader(ctx, c.model.provider, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddToolResults manually adds tool results to the conversation.
|
||||||
|
// Use with SendRaw when handling tool calls manually.
|
||||||
|
func (c *Chat) AddToolResults(results ...Message) {
|
||||||
|
c.messages = append(c.messages, results...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Messages returns the current conversation history (read-only copy).
|
||||||
|
func (c *Chat) Messages() []Message {
|
||||||
|
cp := make([]Message, len(c.messages))
|
||||||
|
copy(cp, c.messages)
|
||||||
|
return cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the conversation history.
|
||||||
|
func (c *Chat) Reset() {
|
||||||
|
c.messages = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fork creates a copy of this chat with identical history, for branching conversations.
|
||||||
|
func (c *Chat) Fork() *Chat {
|
||||||
|
c2 := &Chat{
|
||||||
|
model: c.model,
|
||||||
|
messages: make([]Message, len(c.messages)),
|
||||||
|
tools: c.tools,
|
||||||
|
opts: c.opts,
|
||||||
|
}
|
||||||
|
copy(c2.messages, c.messages)
|
||||||
|
return c2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Chat) buildOpts() []RequestOption {
|
||||||
|
opts := make([]RequestOption, len(c.opts))
|
||||||
|
copy(opts, c.opts)
|
||||||
|
if c.tools != nil {
|
||||||
|
opts = append(opts, WithTools(c.tools))
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
407
v2/chat_test.go
Normal file
407
v2/chat_test.go
Normal file
@@ -0,0 +1,407 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChat_Send(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "Hello there!"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
text, err := chat.Send(context.Background(), "Hi")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if text != "Hello there!" {
|
||||||
|
t.Errorf("expected 'Hello there!', got %q", text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_SendMessage(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "reply"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := chat.Messages()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Role != RoleUser {
|
||||||
|
t.Errorf("expected first message role=user, got %v", msgs[0].Role)
|
||||||
|
}
|
||||||
|
if msgs[0].Content.Text != "msg1" {
|
||||||
|
t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text)
|
||||||
|
}
|
||||||
|
if msgs[1].Role != RoleAssistant {
|
||||||
|
t.Errorf("expected second message role=assistant, got %v", msgs[1].Role)
|
||||||
|
}
|
||||||
|
if msgs[1].Content.Text != "reply" {
|
||||||
|
t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_SetSystem(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
chat.SetSystem("You are a bot")
|
||||||
|
msgs := chat.Messages()
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Role != RoleSystem {
|
||||||
|
t.Errorf("expected role=system, got %v", msgs[0].Role)
|
||||||
|
}
|
||||||
|
if msgs[0].Content.Text != "You are a bot" {
|
||||||
|
t.Errorf("expected system text, got %q", msgs[0].Content.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace system message
|
||||||
|
chat.SetSystem("You are a helpful bot")
|
||||||
|
msgs = chat.Messages()
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected 1 message after replace, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Content.Text != "You are a helpful bot" {
|
||||||
|
t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// System message stays first even after adding other messages
|
||||||
|
_, _ = chat.Send(context.Background(), "Hi")
|
||||||
|
chat.SetSystem("New system")
|
||||||
|
msgs = chat.Messages()
|
||||||
|
if msgs[0].Role != RoleSystem {
|
||||||
|
t.Errorf("expected system as first message, got %v", msgs[0].Role)
|
||||||
|
}
|
||||||
|
if msgs[0].Content.Text != "New system" {
|
||||||
|
t.Errorf("expected 'New system', got %q", msgs[0].Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_ToolCallLoop(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
n := atomic.AddInt32(&callCount, 1)
|
||||||
|
if n == 1 {
|
||||||
|
// First call: request a tool
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// Second call: return text
|
||||||
|
return provider.Response{Text: "done"}, nil
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||||
|
return "hello!", nil
|
||||||
|
})
|
||||||
|
chat.SetTools(NewToolBox(tool))
|
||||||
|
|
||||||
|
text, err := chat.Send(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if text != "done" {
|
||||||
|
t.Errorf("expected 'done', got %q", text)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&callCount) != 2 {
|
||||||
|
t.Errorf("expected 2 provider calls, got %d", callCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check message history: user, assistant (tool call), tool result, assistant (text)
|
||||||
|
msgs := chat.Messages()
|
||||||
|
if len(msgs) != 4 {
|
||||||
|
t.Fatalf("expected 4 messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Role != RoleUser {
|
||||||
|
t.Errorf("msg[0]: expected user, got %v", msgs[0].Role)
|
||||||
|
}
|
||||||
|
if msgs[1].Role != RoleAssistant {
|
||||||
|
t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role)
|
||||||
|
}
|
||||||
|
if len(msgs[1].ToolCalls) != 1 {
|
||||||
|
t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls))
|
||||||
|
}
|
||||||
|
if msgs[2].Role != RoleTool {
|
||||||
|
t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role)
|
||||||
|
}
|
||||||
|
if msgs[2].Content.Text != "hello!" {
|
||||||
|
t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text)
|
||||||
|
}
|
||||||
|
if msgs[3].Role != RoleAssistant {
|
||||||
|
t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_ToolCallLoop_NoTools(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "fake", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
_, err := chat.Send(context.Background(), "test")
|
||||||
|
if !errors.Is(err, ErrNoToolsConfigured) {
|
||||||
|
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_SendRaw(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
Text: "raw response",
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "tool1", Arguments: `{"x":1}`},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "raw response" {
|
||||||
|
t.Errorf("expected 'raw response', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
if !resp.HasToolCalls() {
|
||||||
|
t.Error("expected HasToolCalls() to be true")
|
||||||
|
}
|
||||||
|
if len(resp.ToolCalls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
|
||||||
|
}
|
||||||
|
if resp.ToolCalls[0].Name != "tool1" {
|
||||||
|
t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_SendRaw_ManualToolResults(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
n := atomic.AddInt32(&callCount, 1)
|
||||||
|
if n == 1 {
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "tool1", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return provider.Response{Text: "final"}, nil
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
// First call returns tool calls
|
||||||
|
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.HasToolCalls() {
|
||||||
|
t.Fatal("expected tool calls")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manually add tool result
|
||||||
|
chat.AddToolResults(ToolResultMessage("tc1", "tool result"))
|
||||||
|
|
||||||
|
// Second call returns text
|
||||||
|
resp, err = chat.SendRaw(context.Background(), UserMessage("continue"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "final" {
|
||||||
|
t.Errorf("expected 'final', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the full history
|
||||||
|
msgs := chat.Messages()
|
||||||
|
// user, assistant(tool call), tool result, user, assistant(text)
|
||||||
|
if len(msgs) != 5 {
|
||||||
|
t.Fatalf("expected 5 messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[2].Role != RoleTool {
|
||||||
|
t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role)
|
||||||
|
}
|
||||||
|
if msgs[2].ToolCallID != "tc1" {
|
||||||
|
t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_Messages(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
_, _ = chat.Send(context.Background(), "test")
|
||||||
|
|
||||||
|
msgs := chat.Messages()
|
||||||
|
// Verify it's a copy — modifying returned slice shouldn't affect chat
|
||||||
|
msgs[0] = Message{}
|
||||||
|
|
||||||
|
original := chat.Messages()
|
||||||
|
if original[0].Role != RoleUser {
|
||||||
|
t.Error("Messages() did not return a copy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_Reset(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
_, _ = chat.Send(context.Background(), "test")
|
||||||
|
if len(chat.Messages()) == 0 {
|
||||||
|
t.Fatal("expected messages before reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
chat.Reset()
|
||||||
|
if len(chat.Messages()) != 0 {
|
||||||
|
t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_Fork(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
_, _ = chat.Send(context.Background(), "msg1")
|
||||||
|
|
||||||
|
fork := chat.Fork()
|
||||||
|
|
||||||
|
// Fork should have same history
|
||||||
|
if len(fork.Messages()) != len(chat.Messages()) {
|
||||||
|
t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding to fork should not affect original
|
||||||
|
_, _ = fork.Send(context.Background(), "msg2")
|
||||||
|
if len(fork.Messages()) == len(chat.Messages()) {
|
||||||
|
t.Error("fork messages should be independent of original")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding to original should not affect fork
|
||||||
|
originalLen := len(chat.Messages())
|
||||||
|
_, _ = chat.Send(context.Background(), "msg3")
|
||||||
|
if len(chat.Messages()) == originalLen {
|
||||||
|
t.Error("original should have more messages after send")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_SendWithImages(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "I see an image"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
img := Image{URL: "https://example.com/image.png"}
|
||||||
|
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if text != "I see an image" {
|
||||||
|
t.Errorf("expected 'I see an image', got %q", text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the image was passed through to the provider
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if len(req.Messages) == 0 {
|
||||||
|
t.Fatal("expected messages in request")
|
||||||
|
}
|
||||||
|
lastUserMsg := req.Messages[0]
|
||||||
|
if len(lastUserMsg.Images) != 1 {
|
||||||
|
t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images))
|
||||||
|
}
|
||||||
|
if lastUserMsg.Images[0].URL != "https://example.com/image.png" {
|
||||||
|
t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_MultipleToolCallRounds(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
n := atomic.AddInt32(&callCount, 1)
|
||||||
|
if n <= 3 {
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return provider.Response{Text: "all done"}, nil
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
var execCount int32
|
||||||
|
tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) {
|
||||||
|
atomic.AddInt32(&execCount, 1)
|
||||||
|
return "counted", nil
|
||||||
|
})
|
||||||
|
chat.SetTools(NewToolBox(tool))
|
||||||
|
|
||||||
|
text, err := chat.Send(context.Background(), "count three times")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if text != "all done" {
|
||||||
|
t.Errorf("expected 'all done', got %q", text)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&callCount) != 4 {
|
||||||
|
t.Errorf("expected 4 provider calls, got %d", callCount)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&execCount) != 3 {
|
||||||
|
t.Errorf("expected 3 tool executions, got %d", execCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_SendError(t *testing.T) {
|
||||||
|
wantErr := errors.New("provider failed")
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, wantErr
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
_, err := chat.Send(context.Background(), "test")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Errorf("expected wrapped provider error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_WithRequestOptions(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
|
||||||
|
|
||||||
|
_, err := chat.Send(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if req.Temperature == nil || *req.Temperature != 0.5 {
|
||||||
|
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
|
||||||
|
}
|
||||||
|
if req.MaxTokens == nil || *req.MaxTokens != 200 {
|
||||||
|
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
48
v2/constructors.go
Normal file
48
v2/constructors.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
anthProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/anthropic"
|
||||||
|
googleProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/google"
|
||||||
|
openaiProvider "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAI creates an OpenAI client.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// model := llm.OpenAI("sk-...").Model("gpt-4o")
|
||||||
|
func OpenAI(apiKey string, opts ...ClientOption) *Client {
|
||||||
|
cfg := &clientConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
return NewClient(openaiProvider.New(apiKey, cfg.baseURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anthropic creates an Anthropic client.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// model := llm.Anthropic("sk-ant-...").Model("claude-sonnet-4-20250514")
|
||||||
|
func Anthropic(apiKey string, opts ...ClientOption) *Client {
|
||||||
|
cfg := &clientConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
_ = cfg // Anthropic doesn't support custom base URL in the SDK
|
||||||
|
return NewClient(anthProvider.New(apiKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Google creates a Google (Gemini) client.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// model := llm.Google("...").Model("gemini-2.0-flash")
|
||||||
|
func Google(apiKey string, opts ...ClientOption) *Client {
|
||||||
|
cfg := &clientConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
_ = cfg // Google doesn't support custom base URL in the SDK
|
||||||
|
return NewClient(googleProvider.New(apiKey))
|
||||||
|
}
|
||||||
20
v2/errors.go
Normal file
20
v2/errors.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNoToolsConfigured is returned when the model requests tool calls but no tools are available.
|
||||||
|
ErrNoToolsConfigured = errors.New("model requested tool calls but no tools configured")
|
||||||
|
|
||||||
|
// ErrToolNotFound is returned when a requested tool is not in the toolbox.
|
||||||
|
ErrToolNotFound = errors.New("tool not found")
|
||||||
|
|
||||||
|
// ErrNotConnected is returned when trying to use an MCP server that isn't connected.
|
||||||
|
ErrNotConnected = errors.New("MCP server not connected")
|
||||||
|
|
||||||
|
// ErrStreamClosed is returned when trying to read from a closed stream.
|
||||||
|
ErrStreamClosed = errors.New("stream closed")
|
||||||
|
|
||||||
|
// ErrNoStructuredOutput is returned when the model did not return a structured output tool call.
|
||||||
|
ErrNoStructuredOutput = errors.New("model did not return structured output")
|
||||||
|
)
|
||||||
54
v2/generate.go
Normal file
54
v2/generate.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
const structuredOutputToolName = "structured_output"
|
||||||
|
|
||||||
|
// Generate sends a single user prompt to the model and parses the response into T.
|
||||||
|
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||||
|
// by using a hidden tool call internally.
|
||||||
|
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) {
|
||||||
|
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateWith sends the given messages to the model and parses the response into T.
|
||||||
|
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||||
|
// by using a hidden tool call internally.
|
||||||
|
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
s := schema.FromStruct(zero)
|
||||||
|
|
||||||
|
tool := Tool{
|
||||||
|
Name: structuredOutputToolName,
|
||||||
|
Description: "Return your response as structured data using this function. You MUST call this function with your response.",
|
||||||
|
Schema: s,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append WithTools as the last option so it overrides any user-provided tools.
|
||||||
|
opts = append(opts, WithTools(NewToolBox(tool)))
|
||||||
|
|
||||||
|
resp, err := model.Complete(ctx, messages, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the structured_output tool call in the response.
|
||||||
|
for _, tc := range resp.ToolCalls {
|
||||||
|
if tc.Name == structuredOutputToolName {
|
||||||
|
var result T
|
||||||
|
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
|
||||||
|
return zero, fmt.Errorf("failed to parse structured output: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return zero, ErrNoStructuredOutput
|
||||||
|
}
|
||||||
241
v2/generate_test.go
Normal file
241
v2/generate_test.go
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testPerson struct {
|
||||||
|
Name string `json:"name" description:"The person's name"`
|
||||||
|
Age int `json:"age" description:"The person's age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Alice","age":30}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Alice" {
|
||||||
|
t.Errorf("expected name 'Alice', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if result.Age != 30 {
|
||||||
|
t.Errorf("expected age 30, got %d", result.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the tool was sent in the request
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if len(req.Tools) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||||
|
}
|
||||||
|
if req.Tools[0].Name != "structured_output" {
|
||||||
|
t.Errorf("expected tool name 'structured_output', got %q", req.Tools[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateWith(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Bob","age":25}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
messages := []Message{
|
||||||
|
SystemMessage("You are helpful."),
|
||||||
|
UserMessage("Tell me about Bob"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := GenerateWith[testPerson](context.Background(), model, messages)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Bob" {
|
||||||
|
t.Errorf("expected name 'Bob', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if result.Age != 25 {
|
||||||
|
t.Errorf("expected age 25, got %d", result.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify messages were passed through
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if len(req.Messages) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
|
||||||
|
}
|
||||||
|
if req.Messages[0].Role != "system" {
|
||||||
|
t.Errorf("expected first message role 'system', got %q", req.Messages[0].Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_NoToolCall(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
Text: "I can't use tools right now.",
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNoStructuredOutput) {
|
||||||
|
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_InvalidJSON(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{not valid json}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrNoStructuredOutput) {
|
||||||
|
t.Error("expected parse error, not ErrNoStructuredOutput")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testAddress struct {
|
||||||
|
Street string `json:"street" description:"Street address"`
|
||||||
|
City string `json:"city" description:"City name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPersonWithAddress struct {
|
||||||
|
Name string `json:"name" description:"The person's name"`
|
||||||
|
Age int `json:"age" description:"The person's age"`
|
||||||
|
Address testAddress `json:"address" description:"The person's address"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_NestedStruct(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Carol","age":40,"address":{"street":"123 Main St","city":"Springfield"}}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Carol" {
|
||||||
|
t.Errorf("expected name 'Carol', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if result.Address.Street != "123 Main St" {
|
||||||
|
t.Errorf("expected street '123 Main St', got %q", result.Address.Street)
|
||||||
|
}
|
||||||
|
if result.Address.City != "Springfield" {
|
||||||
|
t.Errorf("expected city 'Springfield', got %q", result.Address.City)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_WithOptions(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Dave","age":35}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
|
||||||
|
WithTemperature(0.5),
|
||||||
|
WithMaxTokens(200),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if req.Temperature == nil || *req.Temperature != 0.5 {
|
||||||
|
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
|
||||||
|
}
|
||||||
|
if req.MaxTokens == nil || *req.MaxTokens != 200 {
|
||||||
|
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_WithMiddleware(t *testing.T) {
|
||||||
|
var middlewareCalled bool
|
||||||
|
mw := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
middlewareCalled = true
|
||||||
|
return next(ctx, model, messages, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Eve","age":28}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp).WithMiddleware(mw)
|
||||||
|
|
||||||
|
result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !middlewareCalled {
|
||||||
|
t.Error("middleware was not called")
|
||||||
|
}
|
||||||
|
if result.Name != "Eve" {
|
||||||
|
t.Errorf("expected name 'Eve', got %q", result.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_WrongToolName(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "some_other_tool",
|
||||||
|
Arguments: `{"name":"Frank","age":50}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNoStructuredOutput) {
|
||||||
|
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
v2/go.mod
Normal file
41
v2/go.mod
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
module gitea.stevedudenhoeffer.com/steve/go-llm/v2
|
||||||
|
|
||||||
|
go 1.24.0
|
||||||
|
|
||||||
|
toolchain go1.24.2
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/liushuangls/go-anthropic/v2 v2.17.0
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||||
|
github.com/openai/openai-go v1.12.0
|
||||||
|
github.com/pkg/sftp v1.13.10
|
||||||
|
golang.org/x/crypto v0.41.0
|
||||||
|
golang.org/x/image v0.35.0
|
||||||
|
google.golang.org/genai v1.45.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
cloud.google.com/go v0.116.0 // indirect
|
||||||
|
cloud.google.com/go/auth v0.9.3 // indirect
|
||||||
|
cloud.google.com/go/compute/metadata v0.5.0 // indirect
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
|
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||||
|
github.com/google/s2a-go v0.1.8 // indirect
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||||
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
|
github.com/kr/fs v0.1.0 // indirect
|
||||||
|
github.com/tidwall/gjson v1.14.4 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
|
go.opencensus.io v0.24.0 // indirect
|
||||||
|
golang.org/x/net v0.42.0 // indirect
|
||||||
|
golang.org/x/oauth2 v0.30.0 // indirect
|
||||||
|
golang.org/x/sys v0.35.0 // indirect
|
||||||
|
golang.org/x/text v0.33.0 // indirect
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||||
|
google.golang.org/grpc v1.66.2 // indirect
|
||||||
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
|
)
|
||||||
165
v2/go.sum
Normal file
165
v2/go.sum
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
|
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
|
||||||
|
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
|
||||||
|
cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U=
|
||||||
|
cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk=
|
||||||
|
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
|
||||||
|
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
|
||||||
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||||
|
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||||
|
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
|
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
|
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||||
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
|
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
|
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||||
|
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||||
|
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||||
|
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||||
|
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||||
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
|
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
|
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||||
|
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
|
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
|
||||||
|
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
|
||||||
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||||
|
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||||
|
github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
|
||||||
|
github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||||
|
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||||
|
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||||
|
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
|
||||||
|
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
||||||
|
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
|
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||||
|
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
|
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||||
|
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||||
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I=
|
||||||
|
golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk=
|
||||||
|
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||||
|
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||||
|
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||||
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
|
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||||
|
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||||
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
|
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||||
|
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
||||||
|
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||||
|
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||||
|
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
|
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
|
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||||
|
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||||
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
|
google.golang.org/genai v1.45.0 h1:s80ZpS42XW0zu/ogiOtenCio17nJ7reEFJjoCftukpA=
|
||||||
|
google.golang.org/genai v1.45.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||||
|
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||||
|
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||||
|
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
|
||||||
|
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||||
|
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||||
|
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||||
|
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||||
|
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||||
|
google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo=
|
||||||
|
google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y=
|
||||||
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
|
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||||
|
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||||
|
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||||
|
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
|
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
|
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||||
|
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||||
|
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
355
v2/google/google.go
Normal file
355
v2/google/google.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
// Package google implements the go-llm v2 provider interface for Google (Gemini).
|
||||||
|
package google
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
|
||||||
|
"google.golang.org/genai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider implements the provider.Provider interface for Google Gemini.
|
||||||
|
type Provider struct {
|
||||||
|
apiKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Google provider.
|
||||||
|
func New(apiKey string) *Provider {
|
||||||
|
return &Provider{apiKey: apiKey}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete performs a non-streaming completion.
|
||||||
|
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||||
|
APIKey: p.apiKey,
|
||||||
|
Backend: genai.BackendGeminiAPI,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return provider.Response{}, fmt.Errorf("google client error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents, cfg := p.buildRequest(req)
|
||||||
|
|
||||||
|
resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return provider.Response{}, fmt.Errorf("google completion error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.convertResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream performs a streaming completion.
|
||||||
|
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||||
|
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||||
|
APIKey: p.apiKey,
|
||||||
|
Backend: genai.BackendGeminiAPI,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("google client error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents, cfg := p.buildRequest(req)
|
||||||
|
|
||||||
|
var fullText strings.Builder
|
||||||
|
var toolCalls []provider.ToolCall
|
||||||
|
|
||||||
|
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("google stream error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range resp.Candidates {
|
||||||
|
if c.Content == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, part := range c.Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
fullText.WriteString(part.Text)
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventText,
|
||||||
|
Text: part.Text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||||
|
tc := provider.ToolCall{
|
||||||
|
ID: part.FunctionCall.Name,
|
||||||
|
Name: part.FunctionCall.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, tc)
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventToolStart,
|
||||||
|
ToolCall: &tc,
|
||||||
|
ToolIndex: len(toolCalls) - 1,
|
||||||
|
}
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventToolEnd,
|
||||||
|
ToolCall: &tc,
|
||||||
|
ToolIndex: len(toolCalls) - 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventDone,
|
||||||
|
Response: &provider.Response{
|
||||||
|
Text: fullText.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) {
|
||||||
|
var contents []*genai.Content
|
||||||
|
cfg := &genai.GenerateContentConfig{}
|
||||||
|
|
||||||
|
for _, tool := range req.Tools {
|
||||||
|
cfg.Tools = append(cfg.Tools, &genai.Tool{
|
||||||
|
FunctionDeclarations: []*genai.FunctionDeclaration{
|
||||||
|
{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Parameters: schemaToGenai(tool.Schema),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Temperature != nil {
|
||||||
|
f := float32(*req.Temperature)
|
||||||
|
cfg.Temperature = &f
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
cfg.MaxOutputTokens = int32(*req.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.TopP != nil {
|
||||||
|
f := float32(*req.TopP)
|
||||||
|
cfg.TopP = &f
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.Stop) > 0 {
|
||||||
|
cfg.StopSequences = req.Stop
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
var role genai.Role
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser)
|
||||||
|
continue
|
||||||
|
case "assistant":
|
||||||
|
role = genai.RoleModel
|
||||||
|
case "tool":
|
||||||
|
// Tool results go as function responses (Genai uses RoleUser for function responses)
|
||||||
|
contents = append(contents, &genai.Content{
|
||||||
|
Role: genai.RoleUser,
|
||||||
|
Parts: []*genai.Part{
|
||||||
|
{
|
||||||
|
FunctionResponse: &genai.FunctionResponse{
|
||||||
|
Name: msg.ToolCallID,
|
||||||
|
Response: map[string]any{
|
||||||
|
"result": msg.Content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
role = genai.RoleUser
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []*genai.Part
|
||||||
|
|
||||||
|
if msg.Content != "" {
|
||||||
|
parts = append(parts, genai.NewPartFromText(msg.Content))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls in assistant messages
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
var args map[string]any
|
||||||
|
if tc.Arguments != "" {
|
||||||
|
_ = json.Unmarshal([]byte(tc.Arguments), &args)
|
||||||
|
}
|
||||||
|
parts = append(parts, &genai.Part{
|
||||||
|
FunctionCall: &genai.FunctionCall{
|
||||||
|
Name: tc.Name,
|
||||||
|
Args: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, img := range msg.Images {
|
||||||
|
if img.URL != "" {
|
||||||
|
// Gemini doesn't support URLs directly; download
|
||||||
|
resp, err := http.Get(img.URL)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := http.DetectContentType(data)
|
||||||
|
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
|
||||||
|
} else if img.Base64 != "" {
|
||||||
|
data, err := base64.StdEncoding.DecodeString(img.Base64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, genai.NewPartFromBytes(data, img.ContentType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, aud := range msg.Audio {
|
||||||
|
if aud.URL != "" {
|
||||||
|
resp, err := http.Get(aud.URL)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := resp.Header.Get("Content-Type")
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = aud.ContentType
|
||||||
|
}
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "audio/wav"
|
||||||
|
}
|
||||||
|
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
|
||||||
|
} else if aud.Base64 != "" {
|
||||||
|
data, err := base64.StdEncoding.DecodeString(aud.Base64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ct := aud.ContentType
|
||||||
|
if ct == "" {
|
||||||
|
ct = "audio/wav"
|
||||||
|
}
|
||||||
|
parts = append(parts, genai.NewPartFromBytes(data, ct))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
contents = append(contents, genai.NewContentFromParts(parts, role))
|
||||||
|
}
|
||||||
|
|
||||||
|
return contents, cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) {
|
||||||
|
var res provider.Response
|
||||||
|
|
||||||
|
for _, c := range resp.Candidates {
|
||||||
|
if c.Content == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, part := range c.Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
res.Text += part.Text
|
||||||
|
}
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||||
|
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
|
||||||
|
ID: part.FunctionCall.Name,
|
||||||
|
Name: part.FunctionCall.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.UsageMetadata != nil {
|
||||||
|
res.Usage = &provider.Usage{
|
||||||
|
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
||||||
|
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
||||||
|
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// schemaToGenai converts a JSON Schema map to a genai.Schema.
|
||||||
|
func schemaToGenai(s map[string]any) *genai.Schema {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := &genai.Schema{}
|
||||||
|
|
||||||
|
if t, ok := s["type"].(string); ok {
|
||||||
|
switch t {
|
||||||
|
case "object":
|
||||||
|
schema.Type = genai.TypeObject
|
||||||
|
case "array":
|
||||||
|
schema.Type = genai.TypeArray
|
||||||
|
case "string":
|
||||||
|
schema.Type = genai.TypeString
|
||||||
|
case "integer":
|
||||||
|
schema.Type = genai.TypeInteger
|
||||||
|
case "number":
|
||||||
|
schema.Type = genai.TypeNumber
|
||||||
|
case "boolean":
|
||||||
|
schema.Type = genai.TypeBoolean
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if desc, ok := s["description"].(string); ok {
|
||||||
|
schema.Description = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
if props, ok := s["properties"].(map[string]any); ok {
|
||||||
|
schema.Properties = make(map[string]*genai.Schema)
|
||||||
|
for k, v := range props {
|
||||||
|
if vm, ok := v.(map[string]any); ok {
|
||||||
|
schema.Properties[k] = schemaToGenai(vm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req, ok := s["required"].([]string); ok {
|
||||||
|
schema.Required = req
|
||||||
|
} else if req, ok := s["required"].([]any); ok {
|
||||||
|
for _, r := range req {
|
||||||
|
if rs, ok := r.(string); ok {
|
||||||
|
schema.Required = append(schema.Required, rs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if items, ok := s["items"].(map[string]any); ok {
|
||||||
|
schema.Items = schemaToGenai(items)
|
||||||
|
}
|
||||||
|
|
||||||
|
if enums, ok := s["enum"].([]string); ok {
|
||||||
|
schema.Enum = enums
|
||||||
|
} else if enums, ok := s["enum"].([]any); ok {
|
||||||
|
for _, e := range enums {
|
||||||
|
if es, ok := e.(string); ok {
|
||||||
|
schema.Enum = append(schema.Enum, es)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
}
|
||||||
105
v2/internal/imageutil/compress.go
Normal file
105
v2/internal/imageutil/compress.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
// Package imageutil provides image compression utilities.
|
||||||
|
package imageutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/gif"
|
||||||
|
"image/jpeg"
|
||||||
|
_ "image/png" // register PNG decoder
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/image/draw"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompressImage takes a base-64-encoded image (JPEG, PNG or GIF) and returns
|
||||||
|
// a base-64-encoded version that is at most maxLength bytes, along with the MIME type.
|
||||||
|
func CompressImage(b64 string, maxLength int) (string, string, error) {
|
||||||
|
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("base64 decode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mime := http.DetectContentType(raw)
|
||||||
|
if len(raw) <= maxLength {
|
||||||
|
return b64, mime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch mime {
|
||||||
|
case "image/gif":
|
||||||
|
return compressGIF(raw, maxLength)
|
||||||
|
default:
|
||||||
|
return compressRaster(raw, maxLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compressRaster(src []byte, maxLength int) (string, string, error) {
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(src))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("decode raster: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
quality := 95
|
||||||
|
for {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil {
|
||||||
|
return "", "", fmt.Errorf("jpeg encode: %w", err)
|
||||||
|
}
|
||||||
|
if buf.Len() <= maxLength {
|
||||||
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/jpeg", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if quality > 20 {
|
||||||
|
quality -= 5
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b := img.Bounds()
|
||||||
|
if b.Dx() < 100 || b.Dy() < 100 {
|
||||||
|
return "", "", fmt.Errorf("cannot compress below %.02fMiB without destroying image", float64(maxLength)/1048576.0)
|
||||||
|
}
|
||||||
|
dst := image.NewRGBA(image.Rect(0, 0, int(float64(b.Dx())*0.8), int(float64(b.Dy())*0.8)))
|
||||||
|
draw.ApproxBiLinear.Scale(dst, dst.Bounds(), img, b, draw.Over, nil)
|
||||||
|
img = dst
|
||||||
|
quality = 95
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compressGIF(src []byte, maxLength int) (string, string, error) {
|
||||||
|
g, err := gif.DecodeAll(bytes.NewReader(src))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("gif decode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := gif.EncodeAll(&buf, g); err != nil {
|
||||||
|
return "", "", fmt.Errorf("gif encode: %w", err)
|
||||||
|
}
|
||||||
|
if buf.Len() <= maxLength {
|
||||||
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), "image/gif", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w, h := g.Config.Width, g.Config.Height
|
||||||
|
if w < 100 || h < 100 {
|
||||||
|
return "", "", fmt.Errorf("cannot compress animated GIF below %.02fMiB", float64(maxLength)/1048576.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
nw, nh := int(float64(w)*0.8), int(float64(h)*0.8)
|
||||||
|
for i, frm := range g.Image {
|
||||||
|
rgba := image.NewRGBA(frm.Bounds())
|
||||||
|
draw.Draw(rgba, rgba.Bounds(), frm, frm.Bounds().Min, draw.Src)
|
||||||
|
|
||||||
|
dst := image.NewRGBA(image.Rect(0, 0, nw, nh))
|
||||||
|
draw.ApproxBiLinear.Scale(dst, dst.Bounds(), rgba, rgba.Bounds(), draw.Over, nil)
|
||||||
|
|
||||||
|
paletted := image.NewPaletted(dst.Bounds(), nil)
|
||||||
|
draw.FloydSteinberg.Draw(paletted, paletted.Bounds(), dst, dst.Bounds().Min)
|
||||||
|
|
||||||
|
g.Image[i] = paletted
|
||||||
|
}
|
||||||
|
g.Config.Width, g.Config.Height = nw, nh
|
||||||
|
}
|
||||||
|
}
|
||||||
188
v2/internal/schema/schema.go
Normal file
188
v2/internal/schema/schema.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
// Package schema provides JSON Schema generation from Go structs.
|
||||||
|
// It produces standard JSON Schema as map[string]any, with no provider-specific types.
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FromStruct generates a JSON Schema object from a Go struct.
|
||||||
|
// Struct tags used:
|
||||||
|
// - `json:"name"` — sets the property name (standard Go JSON convention)
|
||||||
|
// - `description:"..."` — sets the property description
|
||||||
|
// - `enum:"a,b,c"` — restricts string values to the given set
|
||||||
|
//
|
||||||
|
// Pointer fields are treated as optional; non-pointer fields are required.
|
||||||
|
// Anonymous (embedded) struct fields are flattened into the parent.
|
||||||
|
func FromStruct(v any) map[string]any {
|
||||||
|
t := reflect.TypeOf(v)
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
if t.Kind() != reflect.Struct {
|
||||||
|
panic("schema.FromStruct expects a struct or pointer to struct")
|
||||||
|
}
|
||||||
|
return objectSchema(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func objectSchema(t reflect.Type) map[string]any {
|
||||||
|
properties := map[string]any{}
|
||||||
|
var required []string
|
||||||
|
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
|
||||||
|
// Skip unexported fields
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten anonymous (embedded) structs
|
||||||
|
if field.Anonymous {
|
||||||
|
ft := field.Type
|
||||||
|
if ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
if ft.Kind() == reflect.Struct {
|
||||||
|
embedded := objectSchema(ft)
|
||||||
|
if props, ok := embedded["properties"].(map[string]any); ok {
|
||||||
|
for k, v := range props {
|
||||||
|
properties[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req, ok := embedded["required"].([]string); ok {
|
||||||
|
required = append(required, req...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fieldName(field)
|
||||||
|
isRequired := true
|
||||||
|
ft := field.Type
|
||||||
|
|
||||||
|
if ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
isRequired = false
|
||||||
|
}
|
||||||
|
|
||||||
|
prop := fieldSchema(field, ft)
|
||||||
|
properties[name] = prop
|
||||||
|
|
||||||
|
if isRequired {
|
||||||
|
required = append(required, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
}
|
||||||
|
if len(required) > 0 {
|
||||||
|
result["required"] = required
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldSchema(field reflect.StructField, ft reflect.Type) map[string]any {
|
||||||
|
prop := map[string]any{}
|
||||||
|
|
||||||
|
// Check for enum tag first
|
||||||
|
if enumTag, ok := field.Tag.Lookup("enum"); ok {
|
||||||
|
vals := parseEnum(enumTag)
|
||||||
|
prop["type"] = "string"
|
||||||
|
prop["enum"] = vals
|
||||||
|
if desc, ok := field.Tag.Lookup("description"); ok {
|
||||||
|
prop["description"] = desc
|
||||||
|
}
|
||||||
|
return prop
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ft.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
prop["type"] = "string"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
prop["type"] = "integer"
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
prop["type"] = "number"
|
||||||
|
case reflect.Bool:
|
||||||
|
prop["type"] = "boolean"
|
||||||
|
case reflect.Struct:
|
||||||
|
return objectSchema(ft)
|
||||||
|
case reflect.Slice:
|
||||||
|
prop["type"] = "array"
|
||||||
|
elemType := ft.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
prop["items"] = typeSchema(elemType)
|
||||||
|
case reflect.Map:
|
||||||
|
prop["type"] = "object"
|
||||||
|
if ft.Key().Kind() == reflect.String {
|
||||||
|
valType := ft.Elem()
|
||||||
|
if valType.Kind() == reflect.Ptr {
|
||||||
|
valType = valType.Elem()
|
||||||
|
}
|
||||||
|
prop["additionalProperties"] = typeSchema(valType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
prop["type"] = "string" // fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
if desc, ok := field.Tag.Lookup("description"); ok {
|
||||||
|
prop["description"] = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
return prop
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeSchema(t reflect.Type) map[string]any {
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
return map[string]any{"type": "string"}
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return map[string]any{"type": "integer"}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return map[string]any{"type": "number"}
|
||||||
|
case reflect.Bool:
|
||||||
|
return map[string]any{"type": "boolean"}
|
||||||
|
case reflect.Struct:
|
||||||
|
return objectSchema(t)
|
||||||
|
case reflect.Slice:
|
||||||
|
elemType := t.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"type": "array",
|
||||||
|
"items": typeSchema(elemType),
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return map[string]any{"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldName(f reflect.StructField) string {
|
||||||
|
if tag, ok := f.Tag.Lookup("json"); ok {
|
||||||
|
parts := strings.SplitN(tag, ",", 2)
|
||||||
|
if parts[0] != "" && parts[0] != "-" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return f.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEnum(tag string) []string {
|
||||||
|
parts := strings.Split(tag, ",")
|
||||||
|
var vals []string
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p != "" {
|
||||||
|
vals = append(vals, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vals
|
||||||
|
}
|
||||||
181
v2/internal/schema/schema_test.go
Normal file
181
v2/internal/schema/schema_test.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SimpleParams struct {
|
||||||
|
Name string `json:"name" description:"The name"`
|
||||||
|
Age int `json:"age" description:"The age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OptionalParams struct {
|
||||||
|
Required string `json:"required" description:"A required field"`
|
||||||
|
Optional *string `json:"optional,omitempty" description:"An optional field"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnumParams struct {
|
||||||
|
Color string `json:"color" description:"The color" enum:"red,green,blue"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NestedParams struct {
|
||||||
|
Inner SimpleParams `json:"inner" description:"Nested object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ArrayParams struct {
|
||||||
|
Items []string `json:"items" description:"A list of items"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedBase struct {
|
||||||
|
ID string `json:"id" description:"The ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedParams struct {
|
||||||
|
EmbeddedBase
|
||||||
|
Name string `json:"name" description:"The name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_Simple(t *testing.T) {
|
||||||
|
s := FromStruct(SimpleParams{})
|
||||||
|
|
||||||
|
if s["type"] != "object" {
|
||||||
|
t.Errorf("expected type=object, got %v", s["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
props, ok := s["properties"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected properties to be map[string]any")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(props) != 2 {
|
||||||
|
t.Errorf("expected 2 properties, got %d", len(props))
|
||||||
|
}
|
||||||
|
|
||||||
|
nameSchema, ok := props["name"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected name property to be map[string]any")
|
||||||
|
}
|
||||||
|
if nameSchema["type"] != "string" {
|
||||||
|
t.Errorf("expected name type=string, got %v", nameSchema["type"])
|
||||||
|
}
|
||||||
|
if nameSchema["description"] != "The name" {
|
||||||
|
t.Errorf("expected name description='The name', got %v", nameSchema["description"])
|
||||||
|
}
|
||||||
|
|
||||||
|
ageSchema, ok := props["age"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected age property to be map[string]any")
|
||||||
|
}
|
||||||
|
if ageSchema["type"] != "integer" {
|
||||||
|
t.Errorf("expected age type=integer, got %v", ageSchema["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
required, ok := s["required"].([]string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected required to be []string")
|
||||||
|
}
|
||||||
|
if len(required) != 2 {
|
||||||
|
t.Errorf("expected 2 required fields, got %d", len(required))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_Optional(t *testing.T) {
|
||||||
|
s := FromStruct(OptionalParams{})
|
||||||
|
|
||||||
|
required, ok := s["required"].([]string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected required to be []string")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only "required" field should be required, not "optional"
|
||||||
|
if len(required) != 1 {
|
||||||
|
t.Errorf("expected 1 required field, got %d: %v", len(required), required)
|
||||||
|
}
|
||||||
|
if required[0] != "required" {
|
||||||
|
t.Errorf("expected required field 'required', got %v", required[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_Enum(t *testing.T) {
|
||||||
|
s := FromStruct(EnumParams{})
|
||||||
|
|
||||||
|
props := s["properties"].(map[string]any)
|
||||||
|
colorSchema := props["color"].(map[string]any)
|
||||||
|
|
||||||
|
if colorSchema["type"] != "string" {
|
||||||
|
t.Errorf("expected enum type=string, got %v", colorSchema["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
enums, ok := colorSchema["enum"].([]string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected enum to be []string")
|
||||||
|
}
|
||||||
|
if len(enums) != 3 {
|
||||||
|
t.Errorf("expected 3 enum values, got %d", len(enums))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_Nested(t *testing.T) {
|
||||||
|
s := FromStruct(NestedParams{})
|
||||||
|
|
||||||
|
props := s["properties"].(map[string]any)
|
||||||
|
innerSchema := props["inner"].(map[string]any)
|
||||||
|
|
||||||
|
if innerSchema["type"] != "object" {
|
||||||
|
t.Errorf("expected nested type=object, got %v", innerSchema["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
innerProps := innerSchema["properties"].(map[string]any)
|
||||||
|
if len(innerProps) != 2 {
|
||||||
|
t.Errorf("expected 2 inner properties, got %d", len(innerProps))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_Array(t *testing.T) {
|
||||||
|
s := FromStruct(ArrayParams{})
|
||||||
|
|
||||||
|
props := s["properties"].(map[string]any)
|
||||||
|
itemsSchema := props["items"].(map[string]any)
|
||||||
|
|
||||||
|
if itemsSchema["type"] != "array" {
|
||||||
|
t.Errorf("expected array type=array, got %v", itemsSchema["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
items := itemsSchema["items"].(map[string]any)
|
||||||
|
if items["type"] != "string" {
|
||||||
|
t.Errorf("expected items type=string, got %v", items["type"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_Embedded(t *testing.T) {
|
||||||
|
s := FromStruct(EmbeddedParams{})
|
||||||
|
|
||||||
|
props := s["properties"].(map[string]any)
|
||||||
|
|
||||||
|
// Should have both ID from embedded and Name
|
||||||
|
if len(props) != 2 {
|
||||||
|
t.Errorf("expected 2 properties (flattened), got %d", len(props))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := props["id"]; !ok {
|
||||||
|
t.Error("expected 'id' property from embedded struct")
|
||||||
|
}
|
||||||
|
if _, ok := props["name"]; !ok {
|
||||||
|
t.Error("expected 'name' property")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromStruct_ValidJSON(t *testing.T) {
|
||||||
|
s := FromStruct(SimpleParams{})
|
||||||
|
|
||||||
|
data, err := json.Marshal(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("schema should be valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||||
|
t.Fatalf("schema should round-trip through JSON: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
207
v2/llm.go
Normal file
207
v2/llm.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client represents an LLM provider. Create with OpenAI(), Anthropic(), Google().
|
||||||
|
type Client struct {
|
||||||
|
p provider.Provider
|
||||||
|
middleware []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a Client backed by the given provider.
|
||||||
|
// Use this to integrate custom provider implementations or for testing.
|
||||||
|
func NewClient(p provider.Provider) *Client {
|
||||||
|
return &Client{p: p}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model returns a Model for the specified model version.
|
||||||
|
func (c *Client) Model(modelVersion string) *Model {
|
||||||
|
return &Model{
|
||||||
|
provider: c.p,
|
||||||
|
model: modelVersion,
|
||||||
|
middleware: c.middleware,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMiddleware returns a new Client with additional middleware applied to all models.
|
||||||
|
func (c *Client) WithMiddleware(mw ...Middleware) *Client {
|
||||||
|
c2 := &Client{
|
||||||
|
p: c.p,
|
||||||
|
middleware: append(append([]Middleware{}, c.middleware...), mw...),
|
||||||
|
}
|
||||||
|
return c2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model represents a specific model from a provider, ready for completions.
|
||||||
|
type Model struct {
|
||||||
|
provider provider.Provider
|
||||||
|
model string
|
||||||
|
middleware []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete sends a non-streaming completion request.
|
||||||
|
func (m *Model) Complete(ctx context.Context, messages []Message, opts ...RequestOption) (Response, error) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
chain := m.buildChain()
|
||||||
|
return chain(ctx, m.model, messages, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream sends a streaming completion request, returning a StreamReader.
|
||||||
|
func (m *Model) Stream(ctx context.Context, messages []Message, opts ...RequestOption) (*StreamReader, error) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := buildProviderRequest(m.model, messages, cfg)
|
||||||
|
return newStreamReader(ctx, m.provider, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMiddleware returns a new Model with additional middleware applied.
|
||||||
|
func (m *Model) WithMiddleware(mw ...Middleware) *Model {
|
||||||
|
return &Model{
|
||||||
|
provider: m.provider,
|
||||||
|
model: m.model,
|
||||||
|
middleware: append(append([]Middleware{}, m.middleware...), mw...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildChain() CompletionFunc {
|
||||||
|
// Base handler that calls the provider
|
||||||
|
base := func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
req := buildProviderRequest(model, messages, cfg)
|
||||||
|
resp, err := m.provider.Complete(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
return convertProviderResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply middleware in reverse order (first middleware wraps outermost)
|
||||||
|
chain := base
|
||||||
|
for i := len(m.middleware) - 1; i >= 0; i-- {
|
||||||
|
chain = m.middleware[i](chain)
|
||||||
|
}
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProviderRequest(model string, messages []Message, cfg *requestConfig) provider.Request {
|
||||||
|
req := provider.Request{
|
||||||
|
Model: model,
|
||||||
|
Messages: convertMessages(messages),
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.temperature != nil {
|
||||||
|
req.Temperature = cfg.temperature
|
||||||
|
}
|
||||||
|
if cfg.maxTokens != nil {
|
||||||
|
req.MaxTokens = cfg.maxTokens
|
||||||
|
}
|
||||||
|
if cfg.topP != nil {
|
||||||
|
req.TopP = cfg.topP
|
||||||
|
}
|
||||||
|
if len(cfg.stop) > 0 {
|
||||||
|
req.Stop = cfg.stop
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.tools != nil {
|
||||||
|
for _, tool := range cfg.tools.AllTools() {
|
||||||
|
req.Tools = append(req.Tools, provider.ToolDef{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Schema: tool.Schema,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertMessages(msgs []Message) []provider.Message {
|
||||||
|
out := make([]provider.Message, len(msgs))
|
||||||
|
for i, m := range msgs {
|
||||||
|
pm := provider.Message{
|
||||||
|
Role: string(m.Role),
|
||||||
|
Content: m.Content.Text,
|
||||||
|
ToolCallID: m.ToolCallID,
|
||||||
|
}
|
||||||
|
for _, img := range m.Content.Images {
|
||||||
|
pm.Images = append(pm.Images, provider.Image{
|
||||||
|
URL: img.URL,
|
||||||
|
Base64: img.Base64,
|
||||||
|
ContentType: img.ContentType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, aud := range m.Content.Audio {
|
||||||
|
pm.Audio = append(pm.Audio, provider.Audio{
|
||||||
|
URL: aud.URL,
|
||||||
|
Base64: aud.Base64,
|
||||||
|
ContentType: aud.ContentType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, tc := range m.ToolCalls {
|
||||||
|
pm.ToolCalls = append(pm.ToolCalls, provider.ToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: tc.Arguments,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
out[i] = pm
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertProviderResponse(resp provider.Response) Response {
|
||||||
|
r := Response{
|
||||||
|
Text: resp.Text,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range resp.ToolCalls {
|
||||||
|
r.ToolCalls = append(r.ToolCalls, ToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: tc.Arguments,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Usage != nil {
|
||||||
|
r.Usage = &Usage{
|
||||||
|
InputTokens: resp.Usage.InputTokens,
|
||||||
|
OutputTokens: resp.Usage.OutputTokens,
|
||||||
|
TotalTokens: resp.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the assistant message for conversation history
|
||||||
|
r.message = Message{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: Content{Text: resp.Text},
|
||||||
|
ToolCalls: r.ToolCalls,
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Provider constructors ---
|
||||||
|
// These are defined here and delegate to provider-specific packages.
|
||||||
|
// They are set up via init() in the provider packages, or defined directly.
|
||||||
|
|
||||||
|
// ClientOption configures a client.
|
||||||
|
type ClientOption func(*clientConfig)
|
||||||
|
|
||||||
|
type clientConfig struct {
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBaseURL overrides the API base URL.
|
||||||
|
func WithBaseURL(url string) ClientOption {
|
||||||
|
return func(c *clientConfig) { c.baseURL = url }
|
||||||
|
}
|
||||||
264
v2/mcp.go
Normal file
264
v2/mcp.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MCPTransport specifies how to connect to an MCP server.
|
||||||
|
type MCPTransport string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MCPStdio MCPTransport = "stdio"
|
||||||
|
MCPSSE MCPTransport = "sse"
|
||||||
|
MCPHTTP MCPTransport = "http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MCPServer represents a connection to an MCP server.
|
||||||
|
type MCPServer struct {
|
||||||
|
name string
|
||||||
|
transport MCPTransport
|
||||||
|
|
||||||
|
// stdio fields
|
||||||
|
command string
|
||||||
|
args []string
|
||||||
|
env []string
|
||||||
|
|
||||||
|
// network fields
|
||||||
|
url string
|
||||||
|
|
||||||
|
// internal
|
||||||
|
client *mcp.Client
|
||||||
|
session *mcp.ClientSession
|
||||||
|
tools map[string]*mcp.Tool
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPOption configures an MCP server.
|
||||||
|
type MCPOption func(*MCPServer)
|
||||||
|
|
||||||
|
// WithMCPEnv adds environment variables for the subprocess.
|
||||||
|
func WithMCPEnv(env ...string) MCPOption {
|
||||||
|
return func(s *MCPServer) { s.env = env }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMCPName sets a friendly name for logging.
|
||||||
|
func WithMCPName(name string) MCPOption {
|
||||||
|
return func(s *MCPServer) { s.name = name }
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPStdioServer creates and connects to an MCP server via stdio transport.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// server, err := llm.MCPStdioServer(ctx, "npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp")
|
||||||
|
func MCPStdioServer(ctx context.Context, command string, args ...string) (*MCPServer, error) {
|
||||||
|
s := &MCPServer{
|
||||||
|
name: command,
|
||||||
|
transport: MCPStdio,
|
||||||
|
command: command,
|
||||||
|
args: args,
|
||||||
|
}
|
||||||
|
if err := s.connect(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPHTTPServer creates and connects to an MCP server via streamable HTTP transport.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// server, err := llm.MCPHTTPServer(ctx, "https://mcp.example.com")
|
||||||
|
func MCPHTTPServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) {
|
||||||
|
s := &MCPServer{
|
||||||
|
name: url,
|
||||||
|
transport: MCPHTTP,
|
||||||
|
url: url,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(s)
|
||||||
|
}
|
||||||
|
if err := s.connect(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPSSEServer creates and connects to an MCP server via SSE transport.
|
||||||
|
func MCPSSEServer(ctx context.Context, url string, opts ...MCPOption) (*MCPServer, error) {
|
||||||
|
s := &MCPServer{
|
||||||
|
name: url,
|
||||||
|
transport: MCPSSE,
|
||||||
|
url: url,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(s)
|
||||||
|
}
|
||||||
|
if err := s.connect(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MCPServer) connect(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.session != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.client = mcp.NewClient(&mcp.Implementation{
|
||||||
|
Name: "go-llm-v2",
|
||||||
|
Version: "2.0.0",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var transport mcp.Transport
|
||||||
|
|
||||||
|
switch s.transport {
|
||||||
|
case MCPSSE:
|
||||||
|
transport = &mcp.SSEClientTransport{
|
||||||
|
Endpoint: s.url,
|
||||||
|
}
|
||||||
|
case MCPHTTP:
|
||||||
|
transport = &mcp.StreamableClientTransport{
|
||||||
|
Endpoint: s.url,
|
||||||
|
}
|
||||||
|
default: // stdio
|
||||||
|
cmd := exec.Command(s.command, s.args...)
|
||||||
|
cmd.Env = append(os.Environ(), s.env...)
|
||||||
|
transport = &mcp.CommandTransport{
|
||||||
|
Command: cmd,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := s.client.Connect(ctx, transport, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to MCP server %s: %w", s.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.session = session
|
||||||
|
|
||||||
|
// Load tools
|
||||||
|
s.tools = make(map[string]*mcp.Tool)
|
||||||
|
for tool, err := range session.Tools(ctx, nil) {
|
||||||
|
if err != nil {
|
||||||
|
s.session.Close()
|
||||||
|
s.session = nil
|
||||||
|
return fmt.Errorf("failed to list tools from %s: %w", s.name, err)
|
||||||
|
}
|
||||||
|
s.tools[tool.Name] = tool
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection to the MCP server.
|
||||||
|
func (s *MCPServer) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.session == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.session.Close()
|
||||||
|
s.session = nil
|
||||||
|
s.tools = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected returns true if the server is connected.
|
||||||
|
func (s *MCPServer) IsConnected() bool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.session != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTools returns Tool definitions for all tools this server provides.
|
||||||
|
func (s *MCPServer) ListTools() []Tool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
var tools []Tool
|
||||||
|
for _, t := range s.tools {
|
||||||
|
tools = append(tools, s.toTool(t))
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallTool invokes a tool on the server.
|
||||||
|
func (s *MCPServer) CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
session := s.session
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
return "", fmt.Errorf("%w: %s", ErrNotConnected, s.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := session.CallTool(ctx, &mcp.CallToolParams{
|
||||||
|
Name: name,
|
||||||
|
Arguments: arguments,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Content) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentToString(result.Content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MCPServer) toTool(t *mcp.Tool) Tool {
|
||||||
|
var inputSchema map[string]any
|
||||||
|
if t.InputSchema != nil {
|
||||||
|
data, err := json.Marshal(t.InputSchema)
|
||||||
|
if err == nil {
|
||||||
|
_ = json.Unmarshal(data, &inputSchema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if inputSchema == nil {
|
||||||
|
inputSchema = map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tool{
|
||||||
|
Name: t.Name,
|
||||||
|
Description: t.Description,
|
||||||
|
Schema: inputSchema,
|
||||||
|
isMCP: true,
|
||||||
|
mcpServer: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentToString(content []mcp.Content) string {
|
||||||
|
var parts []string
|
||||||
|
for _, c := range content {
|
||||||
|
switch tc := c.(type) {
|
||||||
|
case *mcp.TextContent:
|
||||||
|
parts = append(parts, tc.Text)
|
||||||
|
default:
|
||||||
|
if data, err := json.Marshal(c); err == nil {
|
||||||
|
parts = append(parts, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) == 1 {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(parts)
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
87
v2/message.go
Normal file
87
v2/message.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// Role represents who authored a message.
|
||||||
|
type Role string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RoleSystem Role = "system"
|
||||||
|
RoleUser Role = "user"
|
||||||
|
RoleAssistant Role = "assistant"
|
||||||
|
RoleTool Role = "tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Image represents an image attachment.
|
||||||
|
type Image struct {
|
||||||
|
// Provide exactly one of URL or Base64.
|
||||||
|
URL string // HTTP(S) URL
|
||||||
|
Base64 string // Raw base64-encoded data
|
||||||
|
ContentType string // MIME type (e.g., "image/png"), required for Base64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio represents an audio attachment.
|
||||||
|
type Audio struct {
|
||||||
|
// Provide exactly one of URL or Base64.
|
||||||
|
URL string // HTTP(S) URL to audio file
|
||||||
|
Base64 string // Raw base64-encoded audio data
|
||||||
|
ContentType string // MIME type (e.g., "audio/wav", "audio/mp3")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content represents message content with optional text, images, and audio.
|
||||||
|
type Content struct {
|
||||||
|
Text string
|
||||||
|
Images []Image
|
||||||
|
Audio []Audio
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCall represents a tool invocation requested by the assistant.
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Arguments string // raw JSON
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message represents a single message in a conversation.
|
||||||
|
type Message struct {
|
||||||
|
Role Role
|
||||||
|
Content Content
|
||||||
|
|
||||||
|
// ToolCallID is set when Role == RoleTool, identifying which tool call this responds to.
|
||||||
|
ToolCallID string
|
||||||
|
|
||||||
|
// ToolCalls is set when the assistant requests tool invocations.
|
||||||
|
ToolCalls []ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMessage creates a user message with text content.
|
||||||
|
func UserMessage(text string) Message {
|
||||||
|
return Message{Role: RoleUser, Content: Content{Text: text}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMessageWithImages creates a user message with text and images.
|
||||||
|
func UserMessageWithImages(text string, images ...Image) Message {
|
||||||
|
return Message{Role: RoleUser, Content: Content{Text: text, Images: images}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMessageWithAudio creates a user message with text and audio attachments.
|
||||||
|
func UserMessageWithAudio(text string, audio ...Audio) Message {
|
||||||
|
return Message{Role: RoleUser, Content: Content{Text: text, Audio: audio}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemMessage creates a system prompt message.
|
||||||
|
func SystemMessage(text string) Message {
|
||||||
|
return Message{Role: RoleSystem, Content: Content{Text: text}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssistantMessage creates an assistant message with text content.
|
||||||
|
func AssistantMessage(text string) Message {
|
||||||
|
return Message{Role: RoleAssistant, Content: Content{Text: text}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolResultMessage creates a tool result message.
|
||||||
|
func ToolResultMessage(toolCallID string, result string) Message {
|
||||||
|
return Message{
|
||||||
|
Role: RoleTool,
|
||||||
|
Content: Content{Text: result},
|
||||||
|
ToolCallID: toolCallID,
|
||||||
|
}
|
||||||
|
}
|
||||||
212
v2/message_test.go
Normal file
212
v2/message_test.go
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserMessage(t *testing.T) {
|
||||||
|
msg := UserMessage("hello")
|
||||||
|
if msg.Role != RoleUser {
|
||||||
|
t.Errorf("expected role=user, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "hello" {
|
||||||
|
t.Errorf("expected text='hello', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
if len(msg.Content.Images) != 0 {
|
||||||
|
t.Errorf("expected no images, got %d", len(msg.Content.Images))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserMessageWithImages(t *testing.T) {
|
||||||
|
img1 := Image{URL: "https://example.com/1.png"}
|
||||||
|
img2 := Image{Base64: "abc123", ContentType: "image/png"}
|
||||||
|
|
||||||
|
msg := UserMessageWithImages("describe", img1, img2)
|
||||||
|
if msg.Role != RoleUser {
|
||||||
|
t.Errorf("expected role=user, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "describe" {
|
||||||
|
t.Errorf("expected text='describe', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
if len(msg.Content.Images) != 2 {
|
||||||
|
t.Fatalf("expected 2 images, got %d", len(msg.Content.Images))
|
||||||
|
}
|
||||||
|
if msg.Content.Images[0].URL != "https://example.com/1.png" {
|
||||||
|
t.Errorf("expected image[0] URL, got %q", msg.Content.Images[0].URL)
|
||||||
|
}
|
||||||
|
if msg.Content.Images[1].Base64 != "abc123" {
|
||||||
|
t.Errorf("expected image[1] base64='abc123', got %q", msg.Content.Images[1].Base64)
|
||||||
|
}
|
||||||
|
if msg.Content.Images[1].ContentType != "image/png" {
|
||||||
|
t.Errorf("expected image[1] contentType='image/png', got %q", msg.Content.Images[1].ContentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemMessage(t *testing.T) {
|
||||||
|
msg := SystemMessage("Be helpful")
|
||||||
|
if msg.Role != RoleSystem {
|
||||||
|
t.Errorf("expected role=system, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "Be helpful" {
|
||||||
|
t.Errorf("expected text='Be helpful', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssistantMessage(t *testing.T) {
|
||||||
|
msg := AssistantMessage("Sure thing")
|
||||||
|
if msg.Role != RoleAssistant {
|
||||||
|
t.Errorf("expected role=assistant, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "Sure thing" {
|
||||||
|
t.Errorf("expected text='Sure thing', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolResultMessage(t *testing.T) {
|
||||||
|
msg := ToolResultMessage("tc-123", "result data")
|
||||||
|
if msg.Role != RoleTool {
|
||||||
|
t.Errorf("expected role=tool, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.ToolCallID != "tc-123" {
|
||||||
|
t.Errorf("expected toolCallID='tc-123', got %q", msg.ToolCallID)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "result data" {
|
||||||
|
t.Errorf("expected text='result data', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages(t *testing.T) {
|
||||||
|
msgs := []Message{
|
||||||
|
SystemMessage("system prompt"),
|
||||||
|
UserMessageWithImages("look at this", Image{URL: "https://example.com/img.png"}),
|
||||||
|
{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: Content{Text: "I'll use a tool"},
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{ID: "tc1", Name: "search", Arguments: `{"q":"test"}`},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ToolResultMessage("tc1", "found it"),
|
||||||
|
}
|
||||||
|
|
||||||
|
converted := convertMessages(msgs)
|
||||||
|
|
||||||
|
if len(converted) != 4 {
|
||||||
|
t.Fatalf("expected 4 converted messages, got %d", len(converted))
|
||||||
|
}
|
||||||
|
|
||||||
|
// System message
|
||||||
|
if converted[0].Role != "system" {
|
||||||
|
t.Errorf("msg[0]: expected role='system', got %q", converted[0].Role)
|
||||||
|
}
|
||||||
|
if converted[0].Content != "system prompt" {
|
||||||
|
t.Errorf("msg[0]: expected content='system prompt', got %q", converted[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User message with images
|
||||||
|
if converted[1].Role != "user" {
|
||||||
|
t.Errorf("msg[1]: expected role='user', got %q", converted[1].Role)
|
||||||
|
}
|
||||||
|
if len(converted[1].Images) != 1 {
|
||||||
|
t.Fatalf("msg[1]: expected 1 image, got %d", len(converted[1].Images))
|
||||||
|
}
|
||||||
|
if converted[1].Images[0].URL != "https://example.com/img.png" {
|
||||||
|
t.Errorf("msg[1]: expected image URL, got %q", converted[1].Images[0].URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assistant message with tool calls
|
||||||
|
if converted[2].Role != "assistant" {
|
||||||
|
t.Errorf("msg[2]: expected role='assistant', got %q", converted[2].Role)
|
||||||
|
}
|
||||||
|
if len(converted[2].ToolCalls) != 1 {
|
||||||
|
t.Fatalf("msg[2]: expected 1 tool call, got %d", len(converted[2].ToolCalls))
|
||||||
|
}
|
||||||
|
if converted[2].ToolCalls[0].ID != "tc1" {
|
||||||
|
t.Errorf("msg[2]: expected tool call ID='tc1', got %q", converted[2].ToolCalls[0].ID)
|
||||||
|
}
|
||||||
|
if converted[2].ToolCalls[0].Name != "search" {
|
||||||
|
t.Errorf("msg[2]: expected tool call name='search', got %q", converted[2].ToolCalls[0].Name)
|
||||||
|
}
|
||||||
|
if converted[2].ToolCalls[0].Arguments != `{"q":"test"}` {
|
||||||
|
t.Errorf("msg[2]: expected tool call arguments, got %q", converted[2].ToolCalls[0].Arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool result message
|
||||||
|
if converted[3].Role != "tool" {
|
||||||
|
t.Errorf("msg[3]: expected role='tool', got %q", converted[3].Role)
|
||||||
|
}
|
||||||
|
if converted[3].ToolCallID != "tc1" {
|
||||||
|
t.Errorf("msg[3]: expected toolCallID='tc1', got %q", converted[3].ToolCallID)
|
||||||
|
}
|
||||||
|
if converted[3].Content != "found it" {
|
||||||
|
t.Errorf("msg[3]: expected content='found it', got %q", converted[3].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertProviderResponse(t *testing.T) {
|
||||||
|
t.Run("text only", func(t *testing.T) {
|
||||||
|
resp := convertProviderResponse(provider.Response{
|
||||||
|
Text: "hello",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if resp.Text != "hello" {
|
||||||
|
t.Errorf("expected text='hello', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
if resp.HasToolCalls() {
|
||||||
|
t.Error("expected no tool calls")
|
||||||
|
}
|
||||||
|
if resp.Usage == nil {
|
||||||
|
t.Fatal("expected usage")
|
||||||
|
}
|
||||||
|
if resp.Usage.InputTokens != 10 {
|
||||||
|
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := resp.Message()
|
||||||
|
if msg.Role != RoleAssistant {
|
||||||
|
t.Errorf("expected role=assistant, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "hello" {
|
||||||
|
t.Errorf("expected message text='hello', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with tool calls", func(t *testing.T) {
|
||||||
|
resp := convertProviderResponse(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "search", Arguments: `{"q":"go"}`},
|
||||||
|
{ID: "tc2", Name: "calc", Arguments: `{"a":1}`},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !resp.HasToolCalls() {
|
||||||
|
t.Fatal("expected tool calls")
|
||||||
|
}
|
||||||
|
if len(resp.ToolCalls) != 2 {
|
||||||
|
t.Fatalf("expected 2 tool calls, got %d", len(resp.ToolCalls))
|
||||||
|
}
|
||||||
|
if resp.ToolCalls[0].ID != "tc1" || resp.ToolCalls[0].Name != "search" {
|
||||||
|
t.Errorf("unexpected tool call[0]: %+v", resp.ToolCalls[0])
|
||||||
|
}
|
||||||
|
if resp.ToolCalls[1].ID != "tc2" || resp.ToolCalls[1].Name != "calc" {
|
||||||
|
t.Errorf("unexpected tool call[1]: %+v", resp.ToolCalls[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := resp.Message()
|
||||||
|
if len(msg.ToolCalls) != 2 {
|
||||||
|
t.Errorf("expected 2 tool calls in message, got %d", len(msg.ToolCalls))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil usage", func(t *testing.T) {
|
||||||
|
resp := convertProviderResponse(provider.Response{Text: "ok"})
|
||||||
|
if resp.Usage != nil {
|
||||||
|
t.Errorf("expected nil usage, got %+v", resp.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
117
v2/middleware.go
Normal file
117
v2/middleware.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompletionFunc is the signature for the completion call chain.
|
||||||
|
type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error)
|
||||||
|
|
||||||
|
// Middleware wraps a completion call. It receives the next handler in the chain
|
||||||
|
// and returns a new handler that can inspect/modify the request and response.
|
||||||
|
type Middleware func(next CompletionFunc) CompletionFunc
|
||||||
|
|
||||||
|
// WithLogging returns middleware that logs requests and responses via slog.
|
||||||
|
func WithLogging(logger *slog.Logger) Middleware {
|
||||||
|
return func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
logger.Info("llm request",
|
||||||
|
"model", model,
|
||||||
|
"message_count", len(messages),
|
||||||
|
)
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := next(ctx, model, messages, cfg)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("llm error", "model", model, "elapsed", elapsed, "error", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("llm response",
|
||||||
|
"model", model,
|
||||||
|
"elapsed", elapsed,
|
||||||
|
"text_len", len(resp.Text),
|
||||||
|
"tool_calls", len(resp.ToolCalls),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetry returns middleware that retries failed requests with configurable backoff.
|
||||||
|
func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware {
|
||||||
|
return func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return Response{}, ctx.Err()
|
||||||
|
case <-time.After(backoff(attempt)):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp, err := next(ctx, model, messages, cfg)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTimeout returns middleware that enforces a per-request timeout.
|
||||||
|
func WithTimeout(d time.Duration) Middleware {
|
||||||
|
return func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, d)
|
||||||
|
defer cancel()
|
||||||
|
return next(ctx, model, messages, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageTracker accumulates token usage statistics across calls.
|
||||||
|
type UsageTracker struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
TotalInput int64
|
||||||
|
TotalOutput int64
|
||||||
|
TotalRequests int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add records usage from a single request.
|
||||||
|
func (ut *UsageTracker) Add(u *Usage) {
|
||||||
|
if u == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ut.mu.Lock()
|
||||||
|
defer ut.mu.Unlock()
|
||||||
|
ut.TotalInput += int64(u.InputTokens)
|
||||||
|
ut.TotalOutput += int64(u.OutputTokens)
|
||||||
|
ut.TotalRequests++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary returns the accumulated totals.
|
||||||
|
func (ut *UsageTracker) Summary() (input, output, requests int64) {
|
||||||
|
ut.mu.Lock()
|
||||||
|
defer ut.mu.Unlock()
|
||||||
|
return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUsageTracking returns middleware that accumulates token usage across calls.
|
||||||
|
func WithUsageTracking(tracker *UsageTracker) Middleware {
|
||||||
|
return func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
resp, err := next(ctx, model, messages, cfg)
|
||||||
|
if err == nil {
|
||||||
|
tracker.Add(resp.Usage)
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
282
v2/middleware_test.go
Normal file
282
v2/middleware_test.go
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithRetry_Success(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp).WithMiddleware(
|
||||||
|
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "ok" {
|
||||||
|
t.Errorf("expected 'ok', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
if len(mp.Requests) != 1 {
|
||||||
|
t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetry_EventualSuccess(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
n := atomic.AddInt32(&callCount, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
return provider.Response{}, errors.New("transient error")
|
||||||
|
}
|
||||||
|
return provider.Response{Text: "success"}, nil
|
||||||
|
})
|
||||||
|
model := newMockModel(mp).WithMiddleware(
|
||||||
|
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "success" {
|
||||||
|
t.Errorf("expected 'success', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&callCount) != 3 {
|
||||||
|
t.Errorf("expected 3 calls, got %d", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetry_AllFail(t *testing.T) {
|
||||||
|
providerErr := errors.New("persistent error")
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, providerErr
|
||||||
|
})
|
||||||
|
model := newMockModel(mp).WithMiddleware(
|
||||||
|
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, providerErr) {
|
||||||
|
t.Errorf("expected wrapped persistent error, got %v", err)
|
||||||
|
}
|
||||||
|
if len(mp.Requests) != 3 {
|
||||||
|
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetry_ContextCancelled(t *testing.T) {
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, errors.New("fail")
|
||||||
|
})
|
||||||
|
model := newMockModel(mp).WithMiddleware(
|
||||||
|
WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
// Cancel after a short delay
|
||||||
|
go func() {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := model.Complete(ctx, []Message{UserMessage("test")})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Errorf("expected context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTimeout(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "fast"})
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "fast" {
|
||||||
|
t.Errorf("expected 'fast', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTimeout_Exceeded(t *testing.T) {
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return provider.Response{}, ctx.Err()
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
return provider.Response{Text: "slow"}, nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Errorf("expected DeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithUsageTracking(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
Text: "ok",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tracker := &UsageTracker{}
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||||
|
|
||||||
|
// Make two requests
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on call %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
input, output, requests := tracker.Summary()
|
||||||
|
if input != 20 {
|
||||||
|
t.Errorf("expected total input 20, got %d", input)
|
||||||
|
}
|
||||||
|
if output != 10 {
|
||||||
|
t.Errorf("expected total output 10, got %d", output)
|
||||||
|
}
|
||||||
|
if requests != 2 {
|
||||||
|
t.Errorf("expected 2 requests, got %d", requests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithUsageTracking_NilUsage(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "no usage"})
|
||||||
|
tracker := &UsageTracker{}
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
input, output, requests := tracker.Summary()
|
||||||
|
if input != 0 || output != 0 {
|
||||||
|
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
|
||||||
|
}
|
||||||
|
// Add(nil) returns early without incrementing TotalRequests
|
||||||
|
if requests != 0 {
|
||||||
|
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageTracker_Concurrent(t *testing.T) {
|
||||||
|
tracker := &UsageTracker{}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
tracker.Add(&Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
input, output, requests := tracker.Summary()
|
||||||
|
if input != 1000 {
|
||||||
|
t.Errorf("expected total input 1000, got %d", input)
|
||||||
|
}
|
||||||
|
if output != 500 {
|
||||||
|
t.Errorf("expected total output 500, got %d", output)
|
||||||
|
}
|
||||||
|
if requests != 100 {
|
||||||
|
t.Errorf("expected 100 requests, got %d", requests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddleware_Chaining(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
mw1 := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
order = append(order, "mw1-before")
|
||||||
|
resp, err := next(ctx, model, messages, cfg)
|
||||||
|
order = append(order, "mw1-after")
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mw2 := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
order = append(order, "mw2-before")
|
||||||
|
resp, err := next(ctx, model, messages, cfg)
|
||||||
|
order = append(order, "mw2-after")
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp).WithMiddleware(mw1, mw2)
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
|
||||||
|
if len(order) != len(expected) {
|
||||||
|
t.Fatalf("expected %d middleware calls, got %d: %v", len(expected), len(order), order)
|
||||||
|
}
|
||||||
|
for i, v := range expected {
|
||||||
|
if order[i] != v {
|
||||||
|
t.Errorf("order[%d]: expected %q, got %q", i, v, order[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithLogging(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "logged"})
|
||||||
|
logger := slog.Default()
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "logged" {
|
||||||
|
t.Errorf("expected 'logged', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithLogging_Error(t *testing.T) {
|
||||||
|
providerErr := errors.New("log this error")
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, providerErr
|
||||||
|
})
|
||||||
|
logger := slog.Default()
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if !errors.Is(err, providerErr) {
|
||||||
|
t.Errorf("expected provider error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
87
v2/mock_provider_test.go
Normal file
87
v2/mock_provider_test.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockProvider is a configurable mock implementation of provider.Provider for testing.
|
||||||
|
type mockProvider struct {
|
||||||
|
CompleteFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
|
||||||
|
StreamFunc func(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error
|
||||||
|
|
||||||
|
// mu guards Requests
|
||||||
|
mu sync.Mutex
|
||||||
|
Requests []provider.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.Requests = append(m.Requests, req)
|
||||||
|
m.mu.Unlock()
|
||||||
|
return m.CompleteFunc(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.Requests = append(m.Requests, req)
|
||||||
|
m.mu.Unlock()
|
||||||
|
if m.StreamFunc != nil {
|
||||||
|
return m.StreamFunc(ctx, req, events)
|
||||||
|
}
|
||||||
|
close(events)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastRequest returns the most recent request recorded by the mock.
|
||||||
|
func (m *mockProvider) lastRequest() provider.Request {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if len(m.Requests) == 0 {
|
||||||
|
return provider.Request{}
|
||||||
|
}
|
||||||
|
return m.Requests[len(m.Requests)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMockProvider creates a mock that always returns the given response.
|
||||||
|
func newMockProvider(resp provider.Response) *mockProvider {
|
||||||
|
return &mockProvider{
|
||||||
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return resp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMockProviderFunc creates a mock with a custom Complete function.
|
||||||
|
func newMockProviderFunc(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *mockProvider {
|
||||||
|
return &mockProvider{CompleteFunc: fn}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMockStreamProvider creates a mock that streams the given events.
|
||||||
|
func newMockStreamProvider(events []provider.StreamEvent) *mockProvider {
|
||||||
|
return &mockProvider{
|
||||||
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, nil
|
||||||
|
},
|
||||||
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||||
|
for _, ev := range events {
|
||||||
|
select {
|
||||||
|
case ch <- ev:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMockModel creates a *Model backed by the given mock provider.
|
||||||
|
func newMockModel(p *mockProvider) *Model {
|
||||||
|
return &Model{
|
||||||
|
provider: p,
|
||||||
|
model: "mock-model",
|
||||||
|
}
|
||||||
|
}
|
||||||
215
v2/model_test.go
Normal file
215
v2/model_test.go
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModel_Complete(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
Text: "Hello!",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "Hello!" {
|
||||||
|
t.Errorf("expected text 'Hello!', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
if resp.Usage == nil {
|
||||||
|
t.Fatal("expected usage, got nil")
|
||||||
|
}
|
||||||
|
if resp.Usage.InputTokens != 10 {
|
||||||
|
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
|
||||||
|
}
|
||||||
|
if resp.Usage.OutputTokens != 5 {
|
||||||
|
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
if resp.Usage.TotalTokens != 15 {
|
||||||
|
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_Complete_WithOptions(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
temp := 0.7
|
||||||
|
maxTok := 100
|
||||||
|
topP := 0.9
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
|
||||||
|
WithTemperature(temp),
|
||||||
|
WithMaxTokens(maxTok),
|
||||||
|
WithTopP(topP),
|
||||||
|
WithStop("STOP", "END"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if req.Temperature == nil || *req.Temperature != temp {
|
||||||
|
t.Errorf("expected temperature %v, got %v", temp, req.Temperature)
|
||||||
|
}
|
||||||
|
if req.MaxTokens == nil || *req.MaxTokens != maxTok {
|
||||||
|
t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens)
|
||||||
|
}
|
||||||
|
if req.TopP == nil || *req.TopP != topP {
|
||||||
|
t.Errorf("expected topP %v, got %v", topP, req.TopP)
|
||||||
|
}
|
||||||
|
if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
|
||||||
|
t.Errorf("expected stop [STOP END], got %v", req.Stop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_Complete_Error(t *testing.T) {
|
||||||
|
wantErr := errors.New("provider error")
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, wantErr
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Errorf("expected error %v, got %v", wantErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_Complete_WithTools(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "done"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||||
|
return "hello", nil
|
||||||
|
})
|
||||||
|
tb := NewToolBox(tool)
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if len(req.Tools) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||||
|
}
|
||||||
|
if req.Tools[0].Name != "greet" {
|
||||||
|
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
||||||
|
}
|
||||||
|
if req.Tools[0].Description != "Says hello" {
|
||||||
|
t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_Model(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "hi"})
|
||||||
|
client := NewClient(mp)
|
||||||
|
model := client.Model("test-model")
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "hi" {
|
||||||
|
t.Errorf("expected 'hi', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if req.Model != "test-model" {
|
||||||
|
t.Errorf("expected model 'test-model', got %q", req.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_WithMiddleware(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
mw := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
called = true
|
||||||
|
return next(ctx, model, messages, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
client := NewClient(mp).WithMiddleware(mw)
|
||||||
|
model := client.Model("test-model")
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Error("middleware was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_WithMiddleware(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
mw1 := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
order = append(order, "mw1")
|
||||||
|
return next(ctx, model, messages, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mw2 := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
order = append(order, "mw2")
|
||||||
|
return next(ctx, model, messages, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||||
|
model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2)
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" {
|
||||||
|
t.Errorf("expected middleware order [mw1 mw2], got %v", order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_Complete_NoUsage(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "no usage"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Usage != nil {
|
||||||
|
t.Errorf("expected nil usage, got %+v", resp.Usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_Complete_ResponseMessage(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{Text: "response text"})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := resp.Message()
|
||||||
|
if msg.Role != RoleAssistant {
|
||||||
|
t.Errorf("expected role assistant, got %v", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.Content.Text != "response text" {
|
||||||
|
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
395
v2/openai/openai.go
Normal file
395
v2/openai/openai.go
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
// Package openai implements the go-llm v2 provider interface for OpenAI.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
"github.com/openai/openai-go/option"
|
||||||
|
"github.com/openai/openai-go/packages/param"
|
||||||
|
"github.com/openai/openai-go/shared"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider implements the provider.Provider interface for OpenAI.
|
||||||
|
type Provider struct {
|
||||||
|
apiKey string
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new OpenAI provider.
|
||||||
|
func New(apiKey string, baseURL string) *Provider {
|
||||||
|
return &Provider{apiKey: apiKey, baseURL: baseURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete performs a non-streaming completion.
|
||||||
|
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
opts = append(opts, option.WithAPIKey(p.apiKey))
|
||||||
|
if p.baseURL != "" {
|
||||||
|
opts = append(opts, option.WithBaseURL(p.baseURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
cl := openai.NewClient(opts...)
|
||||||
|
oaiReq := p.buildRequest(req)
|
||||||
|
|
||||||
|
resp, err := cl.Chat.Completions.New(ctx, oaiReq)
|
||||||
|
if err != nil {
|
||||||
|
return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.convertResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream performs a streaming completion.
|
||||||
|
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
opts = append(opts, option.WithAPIKey(p.apiKey))
|
||||||
|
if p.baseURL != "" {
|
||||||
|
opts = append(opts, option.WithBaseURL(p.baseURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
cl := openai.NewClient(opts...)
|
||||||
|
oaiReq := p.buildRequest(req)
|
||||||
|
|
||||||
|
stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
|
||||||
|
|
||||||
|
var fullText strings.Builder
|
||||||
|
var toolCalls []provider.ToolCall
|
||||||
|
toolCallArgs := map[int]*strings.Builder{}
|
||||||
|
|
||||||
|
for stream.Next() {
|
||||||
|
chunk := stream.Current()
|
||||||
|
for _, choice := range chunk.Choices {
|
||||||
|
// Text delta
|
||||||
|
if choice.Delta.Content != "" {
|
||||||
|
fullText.WriteString(choice.Delta.Content)
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventText,
|
||||||
|
Text: choice.Delta.Content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool call deltas
|
||||||
|
for _, tc := range choice.Delta.ToolCalls {
|
||||||
|
idx := int(tc.Index)
|
||||||
|
|
||||||
|
if tc.ID != "" {
|
||||||
|
// New tool call starting
|
||||||
|
for len(toolCalls) <= idx {
|
||||||
|
toolCalls = append(toolCalls, provider.ToolCall{})
|
||||||
|
}
|
||||||
|
toolCalls[idx].ID = tc.ID
|
||||||
|
toolCalls[idx].Name = tc.Function.Name
|
||||||
|
toolCallArgs[idx] = &strings.Builder{}
|
||||||
|
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventToolStart,
|
||||||
|
ToolCall: &provider.ToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
},
|
||||||
|
ToolIndex: idx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.Function.Arguments != "" {
|
||||||
|
if b, ok := toolCallArgs[idx]; ok {
|
||||||
|
b.WriteString(tc.Function.Arguments)
|
||||||
|
}
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventToolDelta,
|
||||||
|
ToolIndex: idx,
|
||||||
|
ToolCall: &provider.ToolCall{
|
||||||
|
Arguments: tc.Function.Arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stream.Err(); err != nil {
|
||||||
|
return fmt.Errorf("openai stream error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize tool calls
|
||||||
|
for idx := range toolCalls {
|
||||||
|
if b, ok := toolCallArgs[idx]; ok {
|
||||||
|
toolCalls[idx].Arguments = b.String()
|
||||||
|
}
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventToolEnd,
|
||||||
|
ToolIndex: idx,
|
||||||
|
ToolCall: &toolCalls[idx],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send done event
|
||||||
|
events <- provider.StreamEvent{
|
||||||
|
Type: provider.StreamEventDone,
|
||||||
|
Response: &provider.Response{
|
||||||
|
Text: fullText.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
|
||||||
|
oaiReq := openai.ChatCompletionNewParams{
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range req.Tools {
|
||||||
|
oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
|
||||||
|
Type: "function",
|
||||||
|
Function: shared.FunctionDefinitionParam{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: openai.String(tool.Description),
|
||||||
|
Parameters: openai.FunctionParameters(tool.Schema),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Temperature != nil {
|
||||||
|
// o* and gpt-5* models don't support custom temperatures
|
||||||
|
if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") {
|
||||||
|
oaiReq.Temperature = openai.Float(*req.Temperature)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.TopP != nil {
|
||||||
|
oaiReq.TopP = openai.Float(*req.TopP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.Stop) > 0 {
|
||||||
|
oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
|
||||||
|
}
|
||||||
|
|
||||||
|
return oaiReq
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
|
||||||
|
var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
|
||||||
|
var textContent param.Opt[string]
|
||||||
|
|
||||||
|
for _, img := range msg.Images {
|
||||||
|
var url string
|
||||||
|
if img.Base64 != "" {
|
||||||
|
url = "data:" + img.ContentType + ";base64," + img.Base64
|
||||||
|
} else if img.URL != "" {
|
||||||
|
url = img.URL
|
||||||
|
}
|
||||||
|
if url != "" {
|
||||||
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
|
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||||
|
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||||
|
URL: url,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, aud := range msg.Audio {
|
||||||
|
var b64Data string
|
||||||
|
var format string
|
||||||
|
|
||||||
|
if aud.Base64 != "" {
|
||||||
|
b64Data = aud.Base64
|
||||||
|
format = audioFormat(aud.ContentType)
|
||||||
|
} else if aud.URL != "" {
|
||||||
|
resp, err := http.Get(aud.URL)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b64Data = base64.StdEncoding.EncodeToString(data)
|
||||||
|
ct := resp.Header.Get("Content-Type")
|
||||||
|
if ct == "" {
|
||||||
|
ct = aud.ContentType
|
||||||
|
}
|
||||||
|
if ct == "" {
|
||||||
|
ct = audioFormatFromURL(aud.URL)
|
||||||
|
}
|
||||||
|
format = audioFormat(ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
if b64Data != "" && format != "" {
|
||||||
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
|
OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
|
||||||
|
InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
|
||||||
|
Data: b64Data,
|
||||||
|
Format: format,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Content != "" {
|
||||||
|
if len(arrayOfContentParts) > 0 {
|
||||||
|
arrayOfContentParts = append(arrayOfContentParts,
|
||||||
|
openai.ChatCompletionContentPartUnionParam{
|
||||||
|
OfText: &openai.ChatCompletionContentPartTextParam{
|
||||||
|
Text: msg.Content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
textContent = openai.String(msg.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine if this model uses developer messages instead of system
|
||||||
|
useDeveloper := false
|
||||||
|
parts := strings.Split(model, "-")
|
||||||
|
if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
|
||||||
|
useDeveloper = true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
if useDeveloper {
|
||||||
|
return openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
|
||||||
|
Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfSystem: &openai.ChatCompletionSystemMessageParam{
|
||||||
|
Content: openai.ChatCompletionSystemMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
case "user":
|
||||||
|
return openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfUser: &openai.ChatCompletionUserMessageParam{
|
||||||
|
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
OfArrayOfContentParts: arrayOfContentParts,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
case "assistant":
|
||||||
|
as := &openai.ChatCompletionAssistantMessageParam{}
|
||||||
|
if msg.Content != "" {
|
||||||
|
as.Content.OfString = openai.String(msg.Content)
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
|
||||||
|
ID: tc.ID,
|
||||||
|
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: tc.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return openai.ChatCompletionMessageParamUnion{OfAssistant: as}
|
||||||
|
|
||||||
|
case "tool":
|
||||||
|
return openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfTool: &openai.ChatCompletionToolMessageParam{
|
||||||
|
ToolCallID: msg.ToolCallID,
|
||||||
|
Content: openai.ChatCompletionToolMessageParamContentUnion{
|
||||||
|
OfString: openai.String(msg.Content),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to user message
|
||||||
|
return openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfUser: &openai.ChatCompletionUserMessageParam{
|
||||||
|
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
||||||
|
OfString: textContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
|
||||||
|
var res provider.Response
|
||||||
|
|
||||||
|
if resp == nil || len(resp.Choices) == 0 {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := resp.Choices[0]
|
||||||
|
res.Text = choice.Message.Content
|
||||||
|
|
||||||
|
for _, tc := range choice.Message.ToolCalls {
|
||||||
|
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: strings.TrimSpace(tc.Function.Arguments),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Usage.TotalTokens > 0 {
|
||||||
|
res.Usage = &provider.Usage{
|
||||||
|
InputTokens: int(resp.Usage.PromptTokens),
|
||||||
|
OutputTokens: int(resp.Usage.CompletionTokens),
|
||||||
|
TotalTokens: int(resp.Usage.TotalTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3").
|
||||||
|
func audioFormat(contentType string) string {
|
||||||
|
ct := strings.ToLower(contentType)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(ct, "wav"):
|
||||||
|
return "wav"
|
||||||
|
case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"):
|
||||||
|
return "mp3"
|
||||||
|
default:
|
||||||
|
return "wav"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// audioFormatFromURL guesses the audio format from a URL's file extension.
|
||||||
|
func audioFormatFromURL(u string) string {
|
||||||
|
ext := strings.ToLower(path.Ext(u))
|
||||||
|
switch ext {
|
||||||
|
case ".mp3":
|
||||||
|
return "audio/mp3"
|
||||||
|
case ".wav":
|
||||||
|
return "audio/wav"
|
||||||
|
default:
|
||||||
|
return "audio/wav"
|
||||||
|
}
|
||||||
|
}
|
||||||
230
v2/openai/transcriber.go
Normal file
230
v2/openai/transcriber.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
"github.com/openai/openai-go/option"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transcriber implements the provider.Transcriber interface using OpenAI's audio models.
|
||||||
|
type Transcriber struct {
|
||||||
|
key string
|
||||||
|
model string
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ provider.Transcriber = (*Transcriber)(nil)
|
||||||
|
|
||||||
|
// NewTranscriber creates a transcriber backed by OpenAI's audio models.
|
||||||
|
// If model is empty, "whisper-1" is used by default.
|
||||||
|
func NewTranscriber(key string, model string) *Transcriber {
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
model = "whisper-1"
|
||||||
|
}
|
||||||
|
return &Transcriber{
|
||||||
|
key: key,
|
||||||
|
model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTranscriberWithBaseURL creates a transcriber with a custom API base URL.
|
||||||
|
func NewTranscriberWithBaseURL(key, model, baseURL string) *Transcriber {
|
||||||
|
t := NewTranscriber(key, model)
|
||||||
|
t.baseURL = baseURL
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transcribe performs speech-to-text transcription of WAV audio data.
|
||||||
|
func (t *Transcriber) Transcribe(ctx context.Context, wav []byte, opts provider.TranscriptionOptions) (provider.Transcription, error) {
|
||||||
|
if len(wav) == 0 {
|
||||||
|
return provider.Transcription{}, fmt.Errorf("wav data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
format := opts.ResponseFormat
|
||||||
|
if format == "" {
|
||||||
|
if strings.HasPrefix(t.model, "gpt-4o") {
|
||||||
|
format = provider.TranscriptionResponseFormatJSON
|
||||||
|
} else {
|
||||||
|
format = provider.TranscriptionResponseFormatVerboseJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if format != provider.TranscriptionResponseFormatJSON && format != provider.TranscriptionResponseFormatVerboseJSON {
|
||||||
|
return provider.Transcription{}, fmt.Errorf("openai transcriber requires response_format json or verbose_json for structured output")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.TimestampGranularities) > 0 && format != provider.TranscriptionResponseFormatVerboseJSON {
|
||||||
|
return provider.Transcription{}, fmt.Errorf("timestamp granularities require response_format=verbose_json")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := openai.AudioTranscriptionNewParams{
|
||||||
|
File: openai.File(bytes.NewReader(wav), "audio.wav", "audio/wav"),
|
||||||
|
Model: openai.AudioModel(t.model),
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Language != "" {
|
||||||
|
params.Language = openai.String(opts.Language)
|
||||||
|
}
|
||||||
|
if opts.Prompt != "" {
|
||||||
|
params.Prompt = openai.String(opts.Prompt)
|
||||||
|
}
|
||||||
|
if opts.Temperature != nil {
|
||||||
|
params.Temperature = openai.Float(*opts.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
params.ResponseFormat = openai.AudioResponseFormat(format)
|
||||||
|
|
||||||
|
if opts.IncludeLogprobs {
|
||||||
|
params.Include = []openai.TranscriptionInclude{openai.TranscriptionIncludeLogprobs}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.TimestampGranularities) > 0 {
|
||||||
|
for _, granularity := range opts.TimestampGranularities {
|
||||||
|
params.TimestampGranularities = append(params.TimestampGranularities, string(granularity))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientOptions := []option.RequestOption{
|
||||||
|
option.WithAPIKey(t.key),
|
||||||
|
}
|
||||||
|
if t.baseURL != "" {
|
||||||
|
clientOptions = append(clientOptions, option.WithBaseURL(t.baseURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
client := openai.NewClient(clientOptions...)
|
||||||
|
resp, err := client.Audio.Transcriptions.New(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return provider.Transcription{}, fmt.Errorf("openai transcription failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return transcriptionToResult(t.model, resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type verboseTranscription struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Language string `json:"language"`
|
||||||
|
Duration float64 `json:"duration"`
|
||||||
|
Segments []verboseSegment `json:"segments"`
|
||||||
|
Words []verboseWord `json:"words"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type verboseSegment struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Start float64 `json:"start"`
|
||||||
|
End float64 `json:"end"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
Tokens []int `json:"tokens"`
|
||||||
|
AvgLogprob *float64 `json:"avg_logprob"`
|
||||||
|
CompressionRatio *float64 `json:"compression_ratio"`
|
||||||
|
NoSpeechProb *float64 `json:"no_speech_prob"`
|
||||||
|
Words []verboseWord `json:"words"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type verboseWord struct {
|
||||||
|
Word string `json:"word"`
|
||||||
|
Start float64 `json:"start"`
|
||||||
|
End float64 `json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func transcriptionToResult(model string, resp *openai.Transcription) provider.Transcription {
|
||||||
|
result := provider.Transcription{
|
||||||
|
Provider: "openai",
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Text = resp.Text
|
||||||
|
result.RawJSON = resp.RawJSON()
|
||||||
|
|
||||||
|
for _, logprob := range resp.Logprobs {
|
||||||
|
result.Logprobs = append(result.Logprobs, provider.TranscriptionTokenLogprob{
|
||||||
|
Token: logprob.Token,
|
||||||
|
Bytes: logprob.Bytes,
|
||||||
|
Logprob: logprob.Logprob,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage := usageToTranscriptionUsage(resp.Usage); usage.Type != "" {
|
||||||
|
result.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RawJSON == "" {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
var verbose verboseTranscription
|
||||||
|
if err := json.Unmarshal([]byte(result.RawJSON), &verbose); err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose.Text != "" {
|
||||||
|
result.Text = verbose.Text
|
||||||
|
}
|
||||||
|
result.Language = verbose.Language
|
||||||
|
result.DurationSeconds = verbose.Duration
|
||||||
|
|
||||||
|
for _, seg := range verbose.Segments {
|
||||||
|
segment := provider.TranscriptionSegment{
|
||||||
|
ID: seg.ID,
|
||||||
|
Start: seg.Start,
|
||||||
|
End: seg.End,
|
||||||
|
Text: seg.Text,
|
||||||
|
Tokens: append([]int(nil), seg.Tokens...),
|
||||||
|
AvgLogprob: seg.AvgLogprob,
|
||||||
|
CompressionRatio: seg.CompressionRatio,
|
||||||
|
NoSpeechProb: seg.NoSpeechProb,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, word := range seg.Words {
|
||||||
|
segment.Words = append(segment.Words, provider.TranscriptionWord{
|
||||||
|
Word: word.Word,
|
||||||
|
Start: word.Start,
|
||||||
|
End: word.End,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Segments = append(result.Segments, segment)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, word := range verbose.Words {
|
||||||
|
result.Words = append(result.Words, provider.TranscriptionWord{
|
||||||
|
Word: word.Word,
|
||||||
|
Start: word.Start,
|
||||||
|
End: word.End,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func usageToTranscriptionUsage(usage openai.TranscriptionUsageUnion) provider.TranscriptionUsage {
|
||||||
|
switch usage.Type {
|
||||||
|
case "tokens":
|
||||||
|
tokens := usage.AsTokens()
|
||||||
|
return provider.TranscriptionUsage{
|
||||||
|
Type: usage.Type,
|
||||||
|
InputTokens: tokens.InputTokens,
|
||||||
|
OutputTokens: tokens.OutputTokens,
|
||||||
|
TotalTokens: tokens.TotalTokens,
|
||||||
|
AudioTokens: tokens.InputTokenDetails.AudioTokens,
|
||||||
|
TextTokens: tokens.InputTokenDetails.TextTokens,
|
||||||
|
}
|
||||||
|
case "duration":
|
||||||
|
duration := usage.AsDuration()
|
||||||
|
return provider.TranscriptionUsage{
|
||||||
|
Type: usage.Type,
|
||||||
|
Seconds: duration.Seconds,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return provider.TranscriptionUsage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
100
v2/provider/provider.go
Normal file
100
v2/provider/provider.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
// Package provider defines the interface that LLM backend implementations must satisfy.
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Message is the provider-level message representation.
|
||||||
|
type Message struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
|
Images []Image
|
||||||
|
Audio []Audio
|
||||||
|
ToolCalls []ToolCall
|
||||||
|
ToolCallID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image represents an image attachment at the provider level.
|
||||||
|
type Image struct {
|
||||||
|
URL string
|
||||||
|
Base64 string
|
||||||
|
ContentType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio represents an audio attachment at the provider level.
|
||||||
|
type Audio struct {
|
||||||
|
URL string
|
||||||
|
Base64 string
|
||||||
|
ContentType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCall represents a tool invocation requested by the model.
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Arguments string // raw JSON
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolDef defines a tool available to the model.
|
||||||
|
type ToolDef struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Schema map[string]any // JSON Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request is a completion request at the provider level.
|
||||||
|
type Request struct {
|
||||||
|
Model string
|
||||||
|
Messages []Message
|
||||||
|
Tools []ToolDef
|
||||||
|
Temperature *float64
|
||||||
|
MaxTokens *int
|
||||||
|
TopP *float64
|
||||||
|
Stop []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response is a completion response at the provider level.
|
||||||
|
type Response struct {
|
||||||
|
Text string
|
||||||
|
ToolCalls []ToolCall
|
||||||
|
Usage *Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage captures token consumption.
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int
|
||||||
|
OutputTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamEventType identifies the kind of stream event.
|
||||||
|
type StreamEventType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StreamEventText StreamEventType = iota // Text content delta
|
||||||
|
StreamEventToolStart // Tool call begins
|
||||||
|
StreamEventToolDelta // Tool call argument delta
|
||||||
|
StreamEventToolEnd // Tool call complete
|
||||||
|
StreamEventDone // Stream complete
|
||||||
|
StreamEventError // Error occurred
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamEvent represents a single event in a streaming response.
|
||||||
|
type StreamEvent struct {
|
||||||
|
Type StreamEventType
|
||||||
|
Text string
|
||||||
|
ToolCall *ToolCall
|
||||||
|
ToolIndex int
|
||||||
|
Error error
|
||||||
|
Response *Response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider is the interface that LLM backends implement.
|
||||||
|
type Provider interface {
|
||||||
|
// Complete performs a non-streaming completion.
|
||||||
|
Complete(ctx context.Context, req Request) (Response, error)
|
||||||
|
|
||||||
|
// Stream performs a streaming completion, sending events to the channel.
|
||||||
|
// The provider MUST close the channel when done.
|
||||||
|
// The provider MUST send exactly one StreamEventDone as the last non-error event.
|
||||||
|
Stream(ctx context.Context, req Request, events chan<- StreamEvent) error
|
||||||
|
}
|
||||||
90
v2/provider/transcription.go
Normal file
90
v2/provider/transcription.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Transcriber abstracts a speech-to-text model implementation.
|
||||||
|
type Transcriber interface {
|
||||||
|
Transcribe(ctx context.Context, wav []byte, opts TranscriptionOptions) (Transcription, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionResponseFormat controls the output format requested from a transcriber.
|
||||||
|
type TranscriptionResponseFormat string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TranscriptionResponseFormatJSON TranscriptionResponseFormat = "json"
|
||||||
|
TranscriptionResponseFormatVerboseJSON TranscriptionResponseFormat = "verbose_json"
|
||||||
|
TranscriptionResponseFormatText TranscriptionResponseFormat = "text"
|
||||||
|
TranscriptionResponseFormatSRT TranscriptionResponseFormat = "srt"
|
||||||
|
TranscriptionResponseFormatVTT TranscriptionResponseFormat = "vtt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TranscriptionTimestampGranularity defines the requested timestamp detail.
|
||||||
|
type TranscriptionTimestampGranularity string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word"
|
||||||
|
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TranscriptionOptions configures transcription behavior.
|
||||||
|
type TranscriptionOptions struct {
|
||||||
|
Language string
|
||||||
|
Prompt string
|
||||||
|
Temperature *float64
|
||||||
|
ResponseFormat TranscriptionResponseFormat
|
||||||
|
TimestampGranularities []TranscriptionTimestampGranularity
|
||||||
|
IncludeLogprobs bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transcription captures a normalized transcription result.
|
||||||
|
type Transcription struct {
|
||||||
|
Provider string
|
||||||
|
Model string
|
||||||
|
Text string
|
||||||
|
Language string
|
||||||
|
DurationSeconds float64
|
||||||
|
Segments []TranscriptionSegment
|
||||||
|
Words []TranscriptionWord
|
||||||
|
Logprobs []TranscriptionTokenLogprob
|
||||||
|
Usage TranscriptionUsage
|
||||||
|
RawJSON string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionSegment provides a coarse time-sliced transcription segment.
|
||||||
|
type TranscriptionSegment struct {
|
||||||
|
ID int
|
||||||
|
Start float64
|
||||||
|
End float64
|
||||||
|
Text string
|
||||||
|
Tokens []int
|
||||||
|
AvgLogprob *float64
|
||||||
|
CompressionRatio *float64
|
||||||
|
NoSpeechProb *float64
|
||||||
|
Words []TranscriptionWord
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionWord provides a word-level timestamp.
|
||||||
|
type TranscriptionWord struct {
|
||||||
|
Word string
|
||||||
|
Start float64
|
||||||
|
End float64
|
||||||
|
Confidence *float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionTokenLogprob captures token-level log probability details.
|
||||||
|
type TranscriptionTokenLogprob struct {
|
||||||
|
Token string
|
||||||
|
Bytes []float64
|
||||||
|
Logprob float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionUsage captures token or duration usage details.
|
||||||
|
type TranscriptionUsage struct {
|
||||||
|
Type string
|
||||||
|
InputTokens int64
|
||||||
|
OutputTokens int64
|
||||||
|
TotalTokens int64
|
||||||
|
AudioTokens int64
|
||||||
|
TextTokens int64
|
||||||
|
Seconds float64
|
||||||
|
}
|
||||||
37
v2/request.go
Normal file
37
v2/request.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// RequestOption configures a single completion request.
|
||||||
|
type RequestOption func(*requestConfig)
|
||||||
|
|
||||||
|
type requestConfig struct {
|
||||||
|
tools *ToolBox
|
||||||
|
temperature *float64
|
||||||
|
maxTokens *int
|
||||||
|
topP *float64
|
||||||
|
stop []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTools attaches a toolbox to the request.
|
||||||
|
func WithTools(tb *ToolBox) RequestOption {
|
||||||
|
return func(c *requestConfig) { c.tools = tb }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTemperature sets the sampling temperature.
|
||||||
|
func WithTemperature(t float64) RequestOption {
|
||||||
|
return func(c *requestConfig) { c.temperature = &t }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxTokens sets the maximum number of tokens to generate.
|
||||||
|
func WithMaxTokens(n int) RequestOption {
|
||||||
|
return func(c *requestConfig) { c.maxTokens = &n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTopP sets the nucleus sampling parameter.
|
||||||
|
func WithTopP(p float64) RequestOption {
|
||||||
|
return func(c *requestConfig) { c.topP = &p }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStop sets stop sequences.
|
||||||
|
func WithStop(sequences ...string) RequestOption {
|
||||||
|
return func(c *requestConfig) { c.stop = sequences }
|
||||||
|
}
|
||||||
137
v2/request_test.go
Normal file
137
v2/request_test.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithTemperature(t *testing.T) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
WithTemperature(0.7)(cfg)
|
||||||
|
if cfg.temperature == nil || *cfg.temperature != 0.7 {
|
||||||
|
t.Errorf("expected temperature 0.7, got %v", cfg.temperature)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithMaxTokens(t *testing.T) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
WithMaxTokens(256)(cfg)
|
||||||
|
if cfg.maxTokens == nil || *cfg.maxTokens != 256 {
|
||||||
|
t.Errorf("expected maxTokens 256, got %v", cfg.maxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTopP(t *testing.T) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
WithTopP(0.95)(cfg)
|
||||||
|
if cfg.topP == nil || *cfg.topP != 0.95 {
|
||||||
|
t.Errorf("expected topP 0.95, got %v", cfg.topP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithStop(t *testing.T) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
WithStop("END", "STOP", "###")(cfg)
|
||||||
|
if len(cfg.stop) != 3 {
|
||||||
|
t.Fatalf("expected 3 stop sequences, got %d", len(cfg.stop))
|
||||||
|
}
|
||||||
|
if cfg.stop[0] != "END" || cfg.stop[1] != "STOP" || cfg.stop[2] != "###" {
|
||||||
|
t.Errorf("unexpected stop sequences: %v", cfg.stop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTools(t *testing.T) {
|
||||||
|
tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) {
|
||||||
|
return "ok", nil
|
||||||
|
})
|
||||||
|
tb := NewToolBox(tool)
|
||||||
|
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
WithTools(tb)(cfg)
|
||||||
|
if cfg.tools == nil {
|
||||||
|
t.Fatal("expected tools to be set")
|
||||||
|
}
|
||||||
|
if len(cfg.tools.AllTools()) != 1 {
|
||||||
|
t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest(t *testing.T) {
|
||||||
|
tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) {
|
||||||
|
return "hi", nil
|
||||||
|
})
|
||||||
|
tb := NewToolBox(tool)
|
||||||
|
|
||||||
|
temp := 0.8
|
||||||
|
maxTok := 512
|
||||||
|
topP := 0.9
|
||||||
|
|
||||||
|
cfg := &requestConfig{
|
||||||
|
tools: tb,
|
||||||
|
temperature: &temp,
|
||||||
|
maxTokens: &maxTok,
|
||||||
|
topP: &topP,
|
||||||
|
stop: []string{"END"},
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := []Message{
|
||||||
|
SystemMessage("be nice"),
|
||||||
|
UserMessage("hello"),
|
||||||
|
}
|
||||||
|
|
||||||
|
req := buildProviderRequest("test-model", msgs, cfg)
|
||||||
|
|
||||||
|
if req.Model != "test-model" {
|
||||||
|
t.Errorf("expected model 'test-model', got %q", req.Model)
|
||||||
|
}
|
||||||
|
if len(req.Messages) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
|
||||||
|
}
|
||||||
|
if req.Messages[0].Role != "system" {
|
||||||
|
t.Errorf("expected first message role='system', got %q", req.Messages[0].Role)
|
||||||
|
}
|
||||||
|
if req.Messages[1].Role != "user" {
|
||||||
|
t.Errorf("expected second message role='user', got %q", req.Messages[1].Role)
|
||||||
|
}
|
||||||
|
if req.Temperature == nil || *req.Temperature != 0.8 {
|
||||||
|
t.Errorf("expected temperature 0.8, got %v", req.Temperature)
|
||||||
|
}
|
||||||
|
if req.MaxTokens == nil || *req.MaxTokens != 512 {
|
||||||
|
t.Errorf("expected maxTokens 512, got %v", req.MaxTokens)
|
||||||
|
}
|
||||||
|
if req.TopP == nil || *req.TopP != 0.9 {
|
||||||
|
t.Errorf("expected topP 0.9, got %v", req.TopP)
|
||||||
|
}
|
||||||
|
if len(req.Stop) != 1 || req.Stop[0] != "END" {
|
||||||
|
t.Errorf("expected stop=[END], got %v", req.Stop)
|
||||||
|
}
|
||||||
|
if len(req.Tools) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||||
|
}
|
||||||
|
if req.Tools[0].Name != "greet" {
|
||||||
|
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest_EmptyConfig(t *testing.T) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
msgs := []Message{UserMessage("hi")}
|
||||||
|
|
||||||
|
req := buildProviderRequest("model", msgs, cfg)
|
||||||
|
|
||||||
|
if req.Temperature != nil {
|
||||||
|
t.Errorf("expected nil temperature, got %v", req.Temperature)
|
||||||
|
}
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
t.Errorf("expected nil maxTokens, got %v", req.MaxTokens)
|
||||||
|
}
|
||||||
|
if req.TopP != nil {
|
||||||
|
t.Errorf("expected nil topP, got %v", req.TopP)
|
||||||
|
}
|
||||||
|
if len(req.Stop) != 0 {
|
||||||
|
t.Errorf("expected no stop sequences, got %v", req.Stop)
|
||||||
|
}
|
||||||
|
if len(req.Tools) != 0 {
|
||||||
|
t.Errorf("expected no tools, got %d", len(req.Tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
34
v2/response.go
Normal file
34
v2/response.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// Response represents the result of a completion request.
|
||||||
|
type Response struct {
|
||||||
|
// Text is the assistant's text content. Empty if only tool calls.
|
||||||
|
Text string
|
||||||
|
|
||||||
|
// ToolCalls contains any tool invocations the assistant requested.
|
||||||
|
ToolCalls []ToolCall
|
||||||
|
|
||||||
|
// Usage contains token usage information (if available from provider).
|
||||||
|
Usage *Usage
|
||||||
|
|
||||||
|
// message is the full assistant message for this response.
|
||||||
|
message Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message returns the full assistant Message for this response,
|
||||||
|
// suitable for appending to the conversation history.
|
||||||
|
func (r Response) Message() Message {
|
||||||
|
return r.message
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolCalls returns true if the response contains tool call requests.
|
||||||
|
func (r Response) HasToolCalls() bool {
|
||||||
|
return len(r.ToolCalls) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage captures token consumption.
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int
|
||||||
|
OutputTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
78
v2/sandbox/doc.go
Normal file
78
v2/sandbox/doc.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
// Package sandbox provides isolated Linux container environments for LLM agents.
|
||||||
|
//
|
||||||
|
// It manages the full lifecycle of Proxmox LXC containers — cloning from a template,
|
||||||
|
// starting, connecting via SSH, executing commands, transferring files, and destroying
|
||||||
|
// the container when done. Each sandbox is an ephemeral, unprivileged container on an
|
||||||
|
// isolated network bridge with no LAN access.
|
||||||
|
//
|
||||||
|
// # Architecture
|
||||||
|
//
|
||||||
|
// The package has three layers:
|
||||||
|
//
|
||||||
|
// - ProxmoxClient: thin REST client for the Proxmox VE API (container CRUD, IP discovery)
|
||||||
|
// - SSHExecutor: persistent SSH/SFTP connection for command execution and file transfer
|
||||||
|
// - Manager/Sandbox: high-level orchestrator that ties Proxmox + SSH together
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// // Load SSH key for container access.
|
||||||
|
// signer, err := sandbox.LoadSSHKey("/etc/mort/sandbox_key")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Create a manager.
|
||||||
|
// mgr, err := sandbox.NewManager(sandbox.Config{
|
||||||
|
// Proxmox: sandbox.ProxmoxConfig{
|
||||||
|
// BaseURL: "https://proxmox.local:8006",
|
||||||
|
// TokenID: "mort-sandbox@pve!sandbox-token",
|
||||||
|
// Secret: os.Getenv("SANDBOX_PROXMOX_SECRET"),
|
||||||
|
// Node: "pve",
|
||||||
|
// TemplateID: 9000,
|
||||||
|
// Pool: "sandbox-pool",
|
||||||
|
// Bridge: "vmbr1",
|
||||||
|
// },
|
||||||
|
// SSH: sandbox.SSHConfig{
|
||||||
|
// Signer: signer,
|
||||||
|
// },
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Create a sandbox.
|
||||||
|
// ctx := context.Background()
|
||||||
|
// sb, err := mgr.Create(ctx,
|
||||||
|
// sandbox.WithHostname("user-abc"),
|
||||||
|
// sandbox.WithInternet(true),
|
||||||
|
// )
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// defer sb.Destroy(ctx)
|
||||||
|
//
|
||||||
|
// // Execute commands.
|
||||||
|
// result, err := sb.Exec(ctx, "apt-get update && apt-get install -y nginx")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// fmt.Printf("exit %d: %s\n", result.ExitCode, result.Output)
|
||||||
|
//
|
||||||
|
// // Write files.
|
||||||
|
// err = sb.WriteFile(ctx, "/var/www/html/index.html", "<h1>Hello</h1>")
|
||||||
|
//
|
||||||
|
// // Read files.
|
||||||
|
// content, err := sb.ReadFile(ctx, "/etc/nginx/nginx.conf")
|
||||||
|
//
|
||||||
|
// # Security
|
||||||
|
//
|
||||||
|
// Sandboxes are secured through defense in depth:
|
||||||
|
// - Unprivileged LXC containers (UID mapping to high host UIDs)
|
||||||
|
// - Isolated network bridge with nftables default-deny outbound
|
||||||
|
// - Per-container opt-in internet access (HTTP/HTTPS only)
|
||||||
|
// - Resource limits: CPU, memory, disk, PID count
|
||||||
|
// - AppArmor confinement (lxc-container-default-cgns)
|
||||||
|
// - Capability dropping (sys_admin, sys_rawio, sys_ptrace, etc.)
|
||||||
|
//
|
||||||
|
// See docs/sandbox-setup.md for the complete Proxmox setup and hardening guide.
|
||||||
|
package sandbox
|
||||||
410
v2/sandbox/proxmox.go
Normal file
410
v2/sandbox/proxmox.go
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxmoxConfig holds configuration for connecting to a Proxmox VE host.
|
||||||
|
type ProxmoxConfig struct {
|
||||||
|
// BaseURL is the Proxmox API base URL (e.g., "https://proxmox.local:8006").
|
||||||
|
BaseURL string
|
||||||
|
|
||||||
|
// TokenID is the API token identifier (e.g., "mort-sandbox@pve!sandbox-token").
|
||||||
|
TokenID string
|
||||||
|
|
||||||
|
// Secret is the API token secret.
|
||||||
|
Secret string
|
||||||
|
|
||||||
|
// Node is the Proxmox node name (e.g., "pve").
|
||||||
|
Node string
|
||||||
|
|
||||||
|
// TemplateID is the LXC template container ID to clone from (e.g., 9000).
|
||||||
|
TemplateID int
|
||||||
|
|
||||||
|
// Pool is the Proxmox resource pool for sandbox containers (e.g., "sandbox-pool").
|
||||||
|
Pool string
|
||||||
|
|
||||||
|
// Bridge is the network bridge for containers (e.g., "vmbr1").
|
||||||
|
Bridge string
|
||||||
|
|
||||||
|
// InsecureSkipVerify disables TLS certificate verification.
|
||||||
|
// Use only for self-signed Proxmox certificates.
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainerStatus represents the current state of a Proxmox LXC container.
|
||||||
|
type ContainerStatus struct {
|
||||||
|
Status string `json:"status"` // "running", "stopped", etc.
|
||||||
|
CPU float64 `json:"cpu"` // CPU usage (0.0–1.0)
|
||||||
|
Mem int64 `json:"mem"` // Current memory usage in bytes
|
||||||
|
MaxMem int64 `json:"maxmem"` // Maximum memory in bytes
|
||||||
|
Disk int64 `json:"disk"` // Current disk usage in bytes
|
||||||
|
MaxDisk int64 `json:"maxdisk"` // Maximum disk in bytes
|
||||||
|
NetIn int64 `json:"netin"` // Network bytes received
|
||||||
|
NetOut int64 `json:"netout"` // Network bytes sent
|
||||||
|
Uptime int64 `json:"uptime"` // Uptime in seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainerConfig holds settings for creating a new container.
|
||||||
|
type ContainerConfig struct {
|
||||||
|
// Hostname for the container.
|
||||||
|
Hostname string
|
||||||
|
|
||||||
|
// CPUs is the number of CPU cores (default 1).
|
||||||
|
CPUs int
|
||||||
|
|
||||||
|
// MemoryMB is the memory limit in megabytes (default 1024).
|
||||||
|
MemoryMB int
|
||||||
|
|
||||||
|
// DiskGB is the root filesystem size in gigabytes (default 8).
|
||||||
|
DiskGB int
|
||||||
|
|
||||||
|
// SSHPublicKey is an optional SSH public key to inject.
|
||||||
|
SSHPublicKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxmoxClient is a thin REST API client for Proxmox VE container lifecycle management.
|
||||||
|
type ProxmoxClient struct {
|
||||||
|
config ProxmoxConfig
|
||||||
|
http *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxmoxClient creates a new Proxmox API client.
|
||||||
|
func NewProxmoxClient(config ProxmoxConfig) *ProxmoxClient {
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &ProxmoxClient{
|
||||||
|
config: config,
|
||||||
|
http: &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextAvailableID queries Proxmox for the next free VMID.
|
||||||
|
func (p *ProxmoxClient) NextAvailableID(ctx context.Context) (int, error) {
|
||||||
|
var result int
|
||||||
|
err := p.get(ctx, "/api2/json/cluster/nextid", &result)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get next VMID: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneTemplate clones the configured template into a new container with the given VMID.
|
||||||
|
func (p *ProxmoxClient) CloneTemplate(ctx context.Context, newID int, cfg ContainerConfig) error {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/clone", p.config.Node, p.config.TemplateID)
|
||||||
|
|
||||||
|
hostname := cfg.Hostname
|
||||||
|
if hostname == "" {
|
||||||
|
hostname = fmt.Sprintf("sandbox-%d", newID)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"newid": {fmt.Sprintf("%d", newID)},
|
||||||
|
"hostname": {hostname},
|
||||||
|
"full": {"1"},
|
||||||
|
}
|
||||||
|
if p.config.Pool != "" {
|
||||||
|
params.Set("pool", p.config.Pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, err := p.post(ctx, path, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("clone template %d → %d: %w", p.config.TemplateID, newID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.waitForTask(ctx, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigureContainer sets CPU, memory, and network on an existing container.
|
||||||
|
func (p *ProxmoxClient) ConfigureContainer(ctx context.Context, id int, cfg ContainerConfig) error {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/config", p.config.Node, id)
|
||||||
|
|
||||||
|
cpus := cfg.CPUs
|
||||||
|
if cpus <= 0 {
|
||||||
|
cpus = 1
|
||||||
|
}
|
||||||
|
mem := cfg.MemoryMB
|
||||||
|
if mem <= 0 {
|
||||||
|
mem = 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"cores": {fmt.Sprintf("%d", cpus)},
|
||||||
|
"memory": {fmt.Sprintf("%d", mem)},
|
||||||
|
"swap": {"0"},
|
||||||
|
"net0": {fmt.Sprintf("name=eth0,bridge=%s,ip=dhcp", p.config.Bridge)},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.put(ctx, path, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("configure container %d: %w", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartContainer starts a stopped container.
|
||||||
|
func (p *ProxmoxClient) StartContainer(ctx context.Context, id int) error {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/status/start", p.config.Node, id)
|
||||||
|
taskID, err := p.post(ctx, path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("start container %d: %w", id, err)
|
||||||
|
}
|
||||||
|
return p.waitForTask(ctx, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopContainer stops a running container.
|
||||||
|
func (p *ProxmoxClient) StopContainer(ctx context.Context, id int) error {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/status/stop", p.config.Node, id)
|
||||||
|
taskID, err := p.post(ctx, path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop container %d: %w", id, err)
|
||||||
|
}
|
||||||
|
return p.waitForTask(ctx, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DestroyContainer stops (if running) and permanently deletes a container.
|
||||||
|
func (p *ProxmoxClient) DestroyContainer(ctx context.Context, id int) error {
|
||||||
|
// Try to stop first; ignore errors (might already be stopped).
|
||||||
|
status, err := p.GetContainerStatus(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get status before destroy: %w", err)
|
||||||
|
}
|
||||||
|
if status.Status == "running" {
|
||||||
|
_ = p.StopContainer(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d", p.config.Node, id)
|
||||||
|
params := url.Values{"force": {"1"}, "purge": {"1"}}
|
||||||
|
taskID, err := p.delete(ctx, path, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("destroy container %d: %w", id, err)
|
||||||
|
}
|
||||||
|
return p.waitForTask(ctx, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetContainerStatus returns the current status and resource usage of a container.
|
||||||
|
func (p *ProxmoxClient) GetContainerStatus(ctx context.Context, id int) (ContainerStatus, error) {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/status/current", p.config.Node, id)
|
||||||
|
var status ContainerStatus
|
||||||
|
if err := p.get(ctx, path, &status); err != nil {
|
||||||
|
return ContainerStatus{}, fmt.Errorf("get container %d status: %w", id, err)
|
||||||
|
}
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetContainerIP discovers the container's IP address by querying its network interfaces.
|
||||||
|
// It polls until an IP is found or the context is cancelled.
|
||||||
|
func (p *ProxmoxClient) GetContainerIP(ctx context.Context, id int) (string, error) {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/lxc/%d/interfaces", p.config.Node, id)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var ifaces []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
HWAddr string `json:"hwaddr"`
|
||||||
|
Inet string `json:"inet"`
|
||||||
|
Inet6 string `json:"inet6"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.get(ctx, path, &ifaces); err == nil {
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if iface.Name == "lo" || iface.Inet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Inet is in CIDR format (e.g., "10.99.1.5/16")
|
||||||
|
ip := iface.Inet
|
||||||
|
if idx := strings.IndexByte(ip, '/'); idx > 0 {
|
||||||
|
ip = ip[:idx]
|
||||||
|
}
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", fmt.Errorf("get container %d IP: %w", id, ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableInternet adds a container IP to the nftables internet_allowed set,
|
||||||
|
// granting outbound HTTP/HTTPS access.
|
||||||
|
func (p *ProxmoxClient) EnableInternet(ctx context.Context, containerIP string) error {
|
||||||
|
return p.execOnHost(ctx, fmt.Sprintf("nft add element inet sandbox internet_allowed { %s }", containerIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableInternet removes a container IP from the nftables internet_allowed set,
|
||||||
|
// revoking outbound HTTP/HTTPS access.
|
||||||
|
func (p *ProxmoxClient) DisableInternet(ctx context.Context, containerIP string) error {
|
||||||
|
return p.execOnHost(ctx, fmt.Sprintf("nft delete element inet sandbox internet_allowed { %s }", containerIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
// execOnHost runs a command on the Proxmox host via the API's node exec endpoint.
|
||||||
|
func (p *ProxmoxClient) execOnHost(ctx context.Context, command string) error {
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/execute", p.config.Node)
|
||||||
|
params := url.Values{"commands": {command}}
|
||||||
|
_, err := p.post(ctx, path, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("exec on host: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- HTTP helpers ---
|
||||||
|
|
||||||
|
// proxmoxResponse is the standard envelope for all Proxmox API responses.
|
||||||
|
type proxmoxResponse struct {
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxmoxClient) doRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||||
|
u := strings.TrimRight(p.config.BaseURL, "/") + path
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, u, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("PVEAPIToken=%s=%s", p.config.TokenID, p.config.Secret))
|
||||||
|
if body != nil {
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxmoxClient) get(ctx context.Context, path string, result any) error {
|
||||||
|
resp, err := p.doRequest(ctx, http.MethodGet, path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return p.parseResponse(resp, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxmoxClient) post(ctx context.Context, path string, params url.Values) (string, error) {
|
||||||
|
var body io.Reader
|
||||||
|
if params != nil {
|
||||||
|
body = strings.NewReader(params.Encode())
|
||||||
|
}
|
||||||
|
resp, err := p.doRequest(ctx, http.MethodPost, path, body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var taskID string
|
||||||
|
if err := p.parseResponse(resp, &taskID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return taskID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxmoxClient) put(ctx context.Context, path string, params url.Values) (string, error) {
|
||||||
|
var body io.Reader
|
||||||
|
if params != nil {
|
||||||
|
body = strings.NewReader(params.Encode())
|
||||||
|
}
|
||||||
|
resp, err := p.doRequest(ctx, http.MethodPut, path, body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result string
|
||||||
|
if err := p.parseResponse(resp, &result); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxmoxClient) delete(ctx context.Context, path string, params url.Values) (string, error) {
|
||||||
|
path = path + "?" + params.Encode()
|
||||||
|
resp, err := p.doRequest(ctx, http.MethodDelete, path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var taskID string
|
||||||
|
if err := p.parseResponse(resp, &taskID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return taskID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxmoxClient) parseResponse(resp *http.Response, result any) error {
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("proxmox API error (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var envelope proxmoxResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&envelope); err != nil {
|
||||||
|
return fmt.Errorf("decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(envelope.Data, result); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal data: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForTask polls a Proxmox task until it completes or the context is cancelled.
|
||||||
|
func (p *ProxmoxClient) waitForTask(ctx context.Context, taskID string) error {
|
||||||
|
if taskID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api2/json/nodes/%s/tasks/%s/status", p.config.Node, url.PathEscape(taskID))
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var status struct {
|
||||||
|
Status string `json:"status"` // "running", "stopped", etc.
|
||||||
|
ExitCode string `json:"exitstatus"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.get(ctx, path, &status); err != nil {
|
||||||
|
return fmt.Errorf("poll task %s: %w", taskID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Status != "running" {
|
||||||
|
if status.ExitCode != "OK" && status.ExitCode != "" {
|
||||||
|
return fmt.Errorf("task %s failed: %s", taskID, status.ExitCode)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("wait for task %s: %w", taskID, ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
310
v2/sandbox/sandbox.go
Normal file
310
v2/sandbox/sandbox.go
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds all configuration for creating sandboxes.
|
||||||
|
type Config struct {
|
||||||
|
Proxmox ProxmoxConfig
|
||||||
|
SSH SSHConfig
|
||||||
|
Defaults ContainerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures a Sandbox before creation.
|
||||||
|
type Option func(*createOpts)
|
||||||
|
|
||||||
|
type createOpts struct {
|
||||||
|
hostname string
|
||||||
|
cpus int
|
||||||
|
memoryMB int
|
||||||
|
diskGB int
|
||||||
|
internet bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHostname sets the container hostname.
|
||||||
|
func WithHostname(name string) Option {
|
||||||
|
return func(o *createOpts) { o.hostname = name }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCPUs sets the number of CPU cores for the container.
|
||||||
|
func WithCPUs(n int) Option {
|
||||||
|
return func(o *createOpts) { o.cpus = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMemoryMB sets the memory limit in megabytes.
|
||||||
|
func WithMemoryMB(mb int) Option {
|
||||||
|
return func(o *createOpts) { o.memoryMB = mb }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDiskGB sets the root filesystem size in gigabytes.
|
||||||
|
func WithDiskGB(gb int) Option {
|
||||||
|
return func(o *createOpts) { o.diskGB = gb }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithInternet enables outbound HTTP/HTTPS access on creation.
|
||||||
|
func WithInternet(enabled bool) Option {
|
||||||
|
return func(o *createOpts) { o.internet = enabled }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sandbox represents an isolated Linux container environment with SSH access.
|
||||||
|
// It wraps a Proxmox LXC container and provides command execution and file operations.
|
||||||
|
type Sandbox struct {
|
||||||
|
// ID is the Proxmox VMID of this container.
|
||||||
|
ID int
|
||||||
|
|
||||||
|
// IP is the container's IP address on the isolated bridge.
|
||||||
|
IP string
|
||||||
|
|
||||||
|
// Internet indicates whether outbound HTTP/HTTPS is enabled.
|
||||||
|
Internet bool
|
||||||
|
|
||||||
|
proxmox *ProxmoxClient
|
||||||
|
ssh *SSHExecutor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager creates and manages sandbox instances.
|
||||||
|
type Manager struct {
|
||||||
|
proxmox *ProxmoxClient
|
||||||
|
sshKey ssh.Signer
|
||||||
|
defaults ContainerConfig
|
||||||
|
sshCfg SSHConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new sandbox manager from the given configuration.
|
||||||
|
func NewManager(cfg Config) (*Manager, error) {
|
||||||
|
if cfg.SSH.Signer == nil {
|
||||||
|
return nil, fmt.Errorf("SSH signer is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Manager{
|
||||||
|
proxmox: NewProxmoxClient(cfg.Proxmox),
|
||||||
|
sshKey: cfg.SSH.Signer,
|
||||||
|
defaults: cfg.Defaults,
|
||||||
|
sshCfg: cfg.SSH,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create provisions a new sandbox container: clones the template, starts it,
|
||||||
|
// waits for SSH, and optionally enables internet access.
|
||||||
|
// The returned Sandbox must be destroyed with Destroy when no longer needed.
|
||||||
|
func (m *Manager) Create(ctx context.Context, opts ...Option) (*Sandbox, error) {
|
||||||
|
o := &createOpts{
|
||||||
|
hostname: m.defaults.Hostname,
|
||||||
|
cpus: m.defaults.CPUs,
|
||||||
|
memoryMB: m.defaults.MemoryMB,
|
||||||
|
diskGB: m.defaults.DiskGB,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults for zero values.
|
||||||
|
if o.cpus <= 0 {
|
||||||
|
o.cpus = 1
|
||||||
|
}
|
||||||
|
if o.memoryMB <= 0 {
|
||||||
|
o.memoryMB = 1024
|
||||||
|
}
|
||||||
|
if o.diskGB <= 0 {
|
||||||
|
o.diskGB = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get next VMID.
|
||||||
|
vmid, err := m.proxmox.NextAvailableID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get next VMID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
containerCfg := ContainerConfig{
|
||||||
|
Hostname: o.hostname,
|
||||||
|
CPUs: o.cpus,
|
||||||
|
MemoryMB: o.memoryMB,
|
||||||
|
DiskGB: o.diskGB,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone template.
|
||||||
|
if err := m.proxmox.CloneTemplate(ctx, vmid, containerCfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("clone template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure container resources.
|
||||||
|
if err := m.proxmox.ConfigureContainer(ctx, vmid, containerCfg); err != nil {
|
||||||
|
// Clean up the cloned container on failure.
|
||||||
|
_ = m.proxmox.DestroyContainer(ctx, vmid)
|
||||||
|
return nil, fmt.Errorf("configure container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start container.
|
||||||
|
if err := m.proxmox.StartContainer(ctx, vmid); err != nil {
|
||||||
|
_ = m.proxmox.DestroyContainer(ctx, vmid)
|
||||||
|
return nil, fmt.Errorf("start container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discover IP address (with timeout).
|
||||||
|
ipCtx, ipCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer ipCancel()
|
||||||
|
|
||||||
|
ip, err := m.proxmox.GetContainerIP(ipCtx, vmid)
|
||||||
|
if err != nil {
|
||||||
|
_ = m.proxmox.DestroyContainer(ctx, vmid)
|
||||||
|
return nil, fmt.Errorf("discover IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect SSH (with timeout).
|
||||||
|
sshExec := NewSSHExecutor(ip, m.sshCfg)
|
||||||
|
|
||||||
|
sshCtx, sshCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer sshCancel()
|
||||||
|
|
||||||
|
if err := sshExec.Connect(sshCtx); err != nil {
|
||||||
|
_ = m.proxmox.DestroyContainer(ctx, vmid)
|
||||||
|
return nil, fmt.Errorf("ssh connect: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb := &Sandbox{
|
||||||
|
ID: vmid,
|
||||||
|
IP: ip,
|
||||||
|
proxmox: m.proxmox,
|
||||||
|
ssh: sshExec,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable internet if requested.
|
||||||
|
if o.internet {
|
||||||
|
if err := sb.SetInternet(ctx, true); err != nil {
|
||||||
|
sb.Destroy(ctx)
|
||||||
|
return nil, fmt.Errorf("enable internet: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach reconnects to an existing sandbox container by VMID.
|
||||||
|
// This is useful for recovering sessions after a restart.
|
||||||
|
func (m *Manager) Attach(ctx context.Context, vmid int) (*Sandbox, error) {
|
||||||
|
status, err := m.proxmox.GetContainerStatus(ctx, vmid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get container status: %w", err)
|
||||||
|
}
|
||||||
|
if status.Status != "running" {
|
||||||
|
return nil, fmt.Errorf("container %d is not running (status: %s)", vmid, status.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := m.proxmox.GetContainerIP(ctx, vmid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get container IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshExec := NewSSHExecutor(ip, m.sshCfg)
|
||||||
|
if err := sshExec.Connect(ctx); err != nil {
|
||||||
|
return nil, fmt.Errorf("ssh connect: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Sandbox{
|
||||||
|
ID: vmid,
|
||||||
|
IP: ip,
|
||||||
|
proxmox: m.proxmox,
|
||||||
|
ssh: sshExec,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec runs a shell command in the sandbox and returns the result.
|
||||||
|
func (s *Sandbox) Exec(ctx context.Context, command string) (ExecResult, error) {
|
||||||
|
return s.ssh.Exec(ctx, command)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFile creates or overwrites a file in the sandbox.
|
||||||
|
func (s *Sandbox) WriteFile(ctx context.Context, path, content string) error {
|
||||||
|
return s.ssh.Upload(ctx, strings.NewReader(content), path, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFile reads a file from the sandbox and returns its contents.
|
||||||
|
func (s *Sandbox) ReadFile(ctx context.Context, path string) (string, error) {
|
||||||
|
rc, err := s.ssh.Download(ctx, path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(rc)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read file %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload copies data from an io.Reader to a file in the sandbox.
|
||||||
|
func (s *Sandbox) Upload(ctx context.Context, reader io.Reader, remotePath string, mode os.FileMode) error {
|
||||||
|
return s.ssh.Upload(ctx, reader, remotePath, mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download returns an io.ReadCloser for a file in the sandbox.
|
||||||
|
// The caller must close the returned reader.
|
||||||
|
func (s *Sandbox) Download(ctx context.Context, remotePath string) (io.ReadCloser, error) {
|
||||||
|
return s.ssh.Download(ctx, remotePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetInternet enables or disables outbound HTTP/HTTPS access for the sandbox.
|
||||||
|
func (s *Sandbox) SetInternet(ctx context.Context, enabled bool) error {
|
||||||
|
if enabled {
|
||||||
|
if err := s.proxmox.EnableInternet(ctx, s.IP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := s.proxmox.DisableInternet(ctx, s.IP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Internet = enabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status returns the current resource usage of the sandbox container.
|
||||||
|
func (s *Sandbox) Status(ctx context.Context) (ContainerStatus, error) {
|
||||||
|
return s.proxmox.GetContainerStatus(ctx, s.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected returns true if the SSH connection to the sandbox is active.
|
||||||
|
func (s *Sandbox) IsConnected() bool {
|
||||||
|
return s.ssh.IsConnected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy stops the container, removes internet access, closes SSH connections,
|
||||||
|
// and permanently deletes the container from Proxmox.
|
||||||
|
func (s *Sandbox) Destroy(ctx context.Context) error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
// Remove internet access first (ignore errors — container is being destroyed).
|
||||||
|
if s.Internet {
|
||||||
|
_ = s.proxmox.DisableInternet(ctx, s.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close SSH connections.
|
||||||
|
if err := s.ssh.Close(); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("close ssh: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy the container.
|
||||||
|
if err := s.proxmox.DestroyContainer(ctx, s.ID); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("destroy container: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return fmt.Errorf("destroy sandbox %d: %v", s.ID, errs)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DestroyByID destroys a container by VMID without requiring an active SSH connection.
|
||||||
|
// This is useful for cleaning up orphaned containers after a restart.
|
||||||
|
func (m *Manager) DestroyByID(ctx context.Context, vmid int) error {
|
||||||
|
return m.proxmox.DestroyContainer(ctx, vmid)
|
||||||
|
}
|
||||||
1996
v2/sandbox/sandbox_test.go
Normal file
1996
v2/sandbox/sandbox_test.go
Normal file
File diff suppressed because it is too large
Load Diff
253
v2/sandbox/ssh.go
Normal file
253
v2/sandbox/ssh.go
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHConfig holds configuration for SSH connections to sandbox containers.
|
||||||
|
type SSHConfig struct {
|
||||||
|
// User is the SSH username (default "sandbox").
|
||||||
|
User string
|
||||||
|
|
||||||
|
// Signer is the SSH private key signer for authentication.
|
||||||
|
Signer ssh.Signer
|
||||||
|
|
||||||
|
// ConnectTimeout is the maximum time to wait for an SSH connection (default 10s).
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
|
||||||
|
// CommandTimeout is the default maximum time for a single command execution (default 60s).
|
||||||
|
CommandTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHExecutor manages SSH and SFTP connections to a sandbox container.
|
||||||
|
type SSHExecutor struct {
|
||||||
|
host string
|
||||||
|
config SSHConfig
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
sshClient *ssh.Client
|
||||||
|
sftpClient *sftp.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSHExecutor creates a new SSH executor for the given host.
|
||||||
|
func NewSSHExecutor(host string, config SSHConfig) *SSHExecutor {
|
||||||
|
if config.User == "" {
|
||||||
|
config.User = "sandbox"
|
||||||
|
}
|
||||||
|
if config.ConnectTimeout <= 0 {
|
||||||
|
config.ConnectTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
if config.CommandTimeout <= 0 {
|
||||||
|
config.CommandTimeout = 60 * time.Second
|
||||||
|
}
|
||||||
|
return &SSHExecutor{
|
||||||
|
host: host,
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes SSH and SFTP connections to the container.
|
||||||
|
// It polls until the connection succeeds or the context is cancelled,
|
||||||
|
// which is useful when waiting for a freshly started container to boot.
|
||||||
|
func (s *SSHExecutor) Connect(ctx context.Context) error {
|
||||||
|
sshConfig := &ssh.ClientConfig{
|
||||||
|
User: s.config.User,
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.PublicKeys(s.config.Signer),
|
||||||
|
},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: s.config.ConnectTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := net.JoinHostPort(s.host, "22")
|
||||||
|
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
client, err := ssh.Dial("tcp", addr, sshConfig)
|
||||||
|
if err == nil {
|
||||||
|
sftpClient, err := sftp.NewClient(client)
|
||||||
|
if err != nil {
|
||||||
|
client.Close()
|
||||||
|
return fmt.Errorf("create SFTP client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.sshClient = client
|
||||||
|
s.sftpClient = sftpClient
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("ssh connect to %s: %w (last error: %v)", addr, ctx.Err(), lastErr)
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecResult contains the output and exit status of a command execution.
|
||||||
|
type ExecResult struct {
|
||||||
|
Output string
|
||||||
|
ExitCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec runs a shell command on the container and returns the combined stdout/stderr
|
||||||
|
// output and exit code.
|
||||||
|
func (s *SSHExecutor) Exec(ctx context.Context, command string) (ExecResult, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
client := s.sshClient
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
return ExecResult{}, fmt.Errorf("ssh not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return ExecResult{}, fmt.Errorf("create session: %w", err)
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
session.Stdout = &buf
|
||||||
|
session.Stderr = &buf
|
||||||
|
|
||||||
|
// Apply context timeout.
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- session.Run(command)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = session.Signal(ssh.SIGKILL)
|
||||||
|
return ExecResult{}, fmt.Errorf("exec timed out: %w", ctx.Err())
|
||||||
|
case err := <-done:
|
||||||
|
output := buf.String()
|
||||||
|
if err != nil {
|
||||||
|
if exitErr, ok := err.(*ssh.ExitError); ok {
|
||||||
|
return ExecResult{
|
||||||
|
Output: output,
|
||||||
|
ExitCode: exitErr.ExitStatus(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return ExecResult{Output: output}, fmt.Errorf("exec: %w", err)
|
||||||
|
}
|
||||||
|
return ExecResult{Output: output, ExitCode: 0}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload writes data from an io.Reader to a file on the container.
|
||||||
|
func (s *SSHExecutor) Upload(ctx context.Context, reader io.Reader, remotePath string, mode os.FileMode) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
client := s.sftpClient
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
return fmt.Errorf("sftp not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := client.OpenFile(remotePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open remote file %s: %w", remotePath, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(f, reader); err != nil {
|
||||||
|
return fmt.Errorf("write to %s: %w", remotePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Chmod(remotePath, mode); err != nil {
|
||||||
|
return fmt.Errorf("chmod %s: %w", remotePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download reads a file from the container and returns its contents as an io.ReadCloser.
|
||||||
|
// The caller must close the returned reader.
|
||||||
|
func (s *SSHExecutor) Download(ctx context.Context, remotePath string) (io.ReadCloser, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
client := s.sftpClient
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
return nil, fmt.Errorf("sftp not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := client.Open(remotePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open remote file %s: %w", remotePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close tears down both SFTP and SSH connections.
|
||||||
|
func (s *SSHExecutor) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
if s.sftpClient != nil {
|
||||||
|
if err := s.sftpClient.Close(); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("close SFTP: %w", err))
|
||||||
|
}
|
||||||
|
s.sftpClient = nil
|
||||||
|
}
|
||||||
|
if s.sshClient != nil {
|
||||||
|
if err := s.sshClient.Close(); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("close SSH: %w", err))
|
||||||
|
}
|
||||||
|
s.sshClient = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return fmt.Errorf("close ssh executor: %v", errs)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected returns true if the SSH connection is established.
|
||||||
|
func (s *SSHExecutor) IsConnected() bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.sshClient != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSSHKey reads a PEM-encoded private key file and returns an ssh.Signer.
|
||||||
|
func LoadSSHKey(path string) (ssh.Signer, error) {
|
||||||
|
keyData, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read SSH key %s: %w", path, err)
|
||||||
|
}
|
||||||
|
signer, err := ssh.ParsePrivateKey(keyData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse SSH key: %w", err)
|
||||||
|
}
|
||||||
|
return signer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseSSHKey parses a PEM-encoded private key from bytes and returns an ssh.Signer.
|
||||||
|
func ParseSSHKey(pemBytes []byte) (ssh.Signer, error) {
|
||||||
|
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse SSH key: %w", err)
|
||||||
|
}
|
||||||
|
return signer, nil
|
||||||
|
}
|
||||||
163
v2/stream.go
Normal file
163
v2/stream.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamEventType identifies the kind of stream event.
|
||||||
|
type StreamEventType = provider.StreamEventType
|
||||||
|
|
||||||
|
const (
|
||||||
|
StreamEventText = provider.StreamEventText
|
||||||
|
StreamEventToolStart = provider.StreamEventToolStart
|
||||||
|
StreamEventToolDelta = provider.StreamEventToolDelta
|
||||||
|
StreamEventToolEnd = provider.StreamEventToolEnd
|
||||||
|
StreamEventDone = provider.StreamEventDone
|
||||||
|
StreamEventError = provider.StreamEventError
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamEvent represents a single event in a streaming response.
|
||||||
|
type StreamEvent struct {
|
||||||
|
Type StreamEventType
|
||||||
|
|
||||||
|
// Text is set for StreamEventText — the text delta.
|
||||||
|
Text string
|
||||||
|
|
||||||
|
// ToolCall is set for StreamEventToolStart/ToolDelta/ToolEnd.
|
||||||
|
ToolCall *ToolCall
|
||||||
|
|
||||||
|
// ToolIndex identifies which tool call is being updated.
|
||||||
|
ToolIndex int
|
||||||
|
|
||||||
|
// Error is set for StreamEventError.
|
||||||
|
Error error
|
||||||
|
|
||||||
|
// Response is set for StreamEventDone — the complete, aggregated response.
|
||||||
|
Response *Response
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamReader reads streaming events from an LLM response.
|
||||||
|
// Must be closed when done.
|
||||||
|
type StreamReader struct {
|
||||||
|
events <-chan StreamEvent
|
||||||
|
cancel context.CancelFunc
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStreamReader(ctx context.Context, p provider.Provider, req provider.Request) (*StreamReader, error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
providerEvents := make(chan provider.StreamEvent, 32)
|
||||||
|
|
||||||
|
publicEvents := make(chan StreamEvent, 32)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(publicEvents)
|
||||||
|
for pev := range providerEvents {
|
||||||
|
ev := convertStreamEvent(pev)
|
||||||
|
select {
|
||||||
|
case publicEvents <- ev:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(providerEvents)
|
||||||
|
if err := p.Stream(ctx, req, providerEvents); err != nil {
|
||||||
|
select {
|
||||||
|
case providerEvents <- provider.StreamEvent{Type: provider.StreamEventError, Error: err}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &StreamReader{
|
||||||
|
events: publicEvents,
|
||||||
|
cancel: cancel,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertStreamEvent(pev provider.StreamEvent) StreamEvent {
|
||||||
|
ev := StreamEvent{
|
||||||
|
Type: pev.Type,
|
||||||
|
Text: pev.Text,
|
||||||
|
ToolIndex: pev.ToolIndex,
|
||||||
|
}
|
||||||
|
if pev.Error != nil {
|
||||||
|
ev.Error = pev.Error
|
||||||
|
}
|
||||||
|
if pev.ToolCall != nil {
|
||||||
|
tc := ToolCall{
|
||||||
|
ID: pev.ToolCall.ID,
|
||||||
|
Name: pev.ToolCall.Name,
|
||||||
|
Arguments: pev.ToolCall.Arguments,
|
||||||
|
}
|
||||||
|
ev.ToolCall = &tc
|
||||||
|
}
|
||||||
|
if pev.Response != nil {
|
||||||
|
resp := convertProviderResponse(*pev.Response)
|
||||||
|
ev.Response = &resp
|
||||||
|
}
|
||||||
|
return ev
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next returns the next event from the stream.
|
||||||
|
// Returns io.EOF when the stream is complete.
|
||||||
|
func (sr *StreamReader) Next() (StreamEvent, error) {
|
||||||
|
if sr.done {
|
||||||
|
return StreamEvent{}, io.EOF
|
||||||
|
}
|
||||||
|
ev, ok := <-sr.events
|
||||||
|
if !ok {
|
||||||
|
sr.done = true
|
||||||
|
return StreamEvent{}, io.EOF
|
||||||
|
}
|
||||||
|
if ev.Type == StreamEventError {
|
||||||
|
return ev, ev.Error
|
||||||
|
}
|
||||||
|
if ev.Type == StreamEventDone {
|
||||||
|
sr.done = true
|
||||||
|
}
|
||||||
|
return ev, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the stream reader and releases resources.
|
||||||
|
func (sr *StreamReader) Close() error {
|
||||||
|
sr.cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect reads all events and returns the final aggregated Response.
|
||||||
|
func (sr *StreamReader) Collect() (Response, error) {
|
||||||
|
var lastResp *Response
|
||||||
|
for {
|
||||||
|
ev, err := sr.Next()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
if ev.Type == StreamEventDone && ev.Response != nil {
|
||||||
|
lastResp = ev.Response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastResp == nil {
|
||||||
|
return Response{}, fmt.Errorf("stream completed without final response")
|
||||||
|
}
|
||||||
|
return *lastResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text is a convenience that collects the stream and returns just the text.
|
||||||
|
func (sr *StreamReader) Text() (string, error) {
|
||||||
|
resp, err := sr.Collect()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return resp.Text, nil
|
||||||
|
}
|
||||||
338
v2/stream_test.go
Normal file
338
v2/stream_test.go
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStreamReader_TextEvents(t *testing.T) {
|
||||||
|
events := []provider.StreamEvent{
|
||||||
|
{Type: provider.StreamEventText, Text: "Hello"},
|
||||||
|
{Type: provider.StreamEventText, Text: " world"},
|
||||||
|
{Type: provider.StreamEventDone, Response: &provider.Response{
|
||||||
|
Text: "Hello world",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 2,
|
||||||
|
TotalTokens: 7,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
mp := newMockStreamProvider(events)
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Read text events
|
||||||
|
ev, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on first event: %v", err)
|
||||||
|
}
|
||||||
|
if ev.Type != StreamEventText || ev.Text != "Hello" {
|
||||||
|
t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
ev, err = reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on second event: %v", err)
|
||||||
|
}
|
||||||
|
if ev.Type != StreamEventText || ev.Text != " world" {
|
||||||
|
t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read done event
|
||||||
|
ev, err = reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on done event: %v", err)
|
||||||
|
}
|
||||||
|
if ev.Type != StreamEventDone {
|
||||||
|
t.Errorf("expected done event, got type=%d", ev.Type)
|
||||||
|
}
|
||||||
|
if ev.Response == nil {
|
||||||
|
t.Fatal("expected response in done event")
|
||||||
|
}
|
||||||
|
if ev.Response.Text != "Hello world" {
|
||||||
|
t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent reads should return EOF
|
||||||
|
_, err = reader.Next()
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Errorf("expected io.EOF after done, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_ToolCallEvents(t *testing.T) {
|
||||||
|
events := []provider.StreamEvent{
|
||||||
|
{
|
||||||
|
Type: provider.StreamEventToolStart,
|
||||||
|
ToolIndex: 0,
|
||||||
|
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: provider.StreamEventToolDelta,
|
||||||
|
ToolIndex: 0,
|
||||||
|
ToolCall: &provider.ToolCall{Arguments: `{"query":`},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: provider.StreamEventToolDelta,
|
||||||
|
ToolIndex: 0,
|
||||||
|
ToolCall: &provider.ToolCall{Arguments: `"test"}`},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: provider.StreamEventToolEnd,
|
||||||
|
ToolIndex: 0,
|
||||||
|
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: provider.StreamEventDone,
|
||||||
|
Response: &provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mp := newMockStreamProvider(events)
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Read tool start
|
||||||
|
ev, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if ev.Type != StreamEventToolStart {
|
||||||
|
t.Errorf("expected tool start, got type=%d", ev.Type)
|
||||||
|
}
|
||||||
|
if ev.ToolCall == nil || ev.ToolCall.Name != "search" {
|
||||||
|
t.Errorf("expected tool call 'search', got %+v", ev.ToolCall)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read tool deltas
|
||||||
|
ev, _ = reader.Next()
|
||||||
|
if ev.Type != StreamEventToolDelta {
|
||||||
|
t.Errorf("expected tool delta, got type=%d", ev.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
ev, _ = reader.Next()
|
||||||
|
if ev.Type != StreamEventToolDelta {
|
||||||
|
t.Errorf("expected tool delta, got type=%d", ev.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read tool end
|
||||||
|
ev, _ = reader.Next()
|
||||||
|
if ev.Type != StreamEventToolEnd {
|
||||||
|
t.Errorf("expected tool end, got type=%d", ev.Type)
|
||||||
|
}
|
||||||
|
if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` {
|
||||||
|
t.Errorf("expected complete arguments, got %+v", ev.ToolCall)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read done
|
||||||
|
ev, _ = reader.Next()
|
||||||
|
if ev.Type != StreamEventDone {
|
||||||
|
t.Errorf("expected done, got type=%d", ev.Type)
|
||||||
|
}
|
||||||
|
if ev.Response == nil || len(ev.Response.ToolCalls) != 1 {
|
||||||
|
t.Error("expected response with 1 tool call")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_Error(t *testing.T) {
|
||||||
|
streamErr := errors.New("stream failed")
|
||||||
|
mp := &mockProvider{
|
||||||
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, nil
|
||||||
|
},
|
||||||
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||||
|
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"}
|
||||||
|
ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Read partial text
|
||||||
|
ev, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if ev.Text != "partial" {
|
||||||
|
t.Errorf("expected 'partial', got %q", ev.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read error
|
||||||
|
_, err = reader.Next()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, streamErr) {
|
||||||
|
t.Errorf("expected stream error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_Close(t *testing.T) {
|
||||||
|
// Create a stream that sends one event then blocks until context is cancelled
|
||||||
|
mp := &mockProvider{
|
||||||
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, nil
|
||||||
|
},
|
||||||
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||||
|
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"}
|
||||||
|
<-ctx.Done()
|
||||||
|
return ctx.Err()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the first event
|
||||||
|
ev, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on first event: %v", err)
|
||||||
|
}
|
||||||
|
if ev.Text != "start" {
|
||||||
|
t.Errorf("expected 'start', got %q", ev.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close should cancel context
|
||||||
|
if err := reader.Close(); err != nil {
|
||||||
|
t.Fatalf("close error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// After close, Next should eventually terminate with either EOF or context error.
|
||||||
|
// The exact behavior depends on goroutine scheduling: the channel may close (EOF)
|
||||||
|
// or the error event from the cancelled context may arrive first.
|
||||||
|
_, err = reader.Next()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error after close, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_Collect(t *testing.T) {
|
||||||
|
events := []provider.StreamEvent{
|
||||||
|
{Type: provider.StreamEventText, Text: "Hello"},
|
||||||
|
{Type: provider.StreamEventText, Text: " world"},
|
||||||
|
{Type: provider.StreamEventDone, Response: &provider.Response{
|
||||||
|
Text: "Hello world",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 2,
|
||||||
|
TotalTokens: 12,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
mp := newMockStreamProvider(events)
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
resp, err := reader.Collect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("collect error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text != "Hello world" {
|
||||||
|
t.Errorf("expected 'Hello world', got %q", resp.Text)
|
||||||
|
}
|
||||||
|
if resp.Usage == nil {
|
||||||
|
t.Fatal("expected usage")
|
||||||
|
}
|
||||||
|
if resp.Usage.InputTokens != 10 {
|
||||||
|
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_Text(t *testing.T) {
|
||||||
|
events := []provider.StreamEvent{
|
||||||
|
{Type: provider.StreamEventText, Text: "result"},
|
||||||
|
{Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}},
|
||||||
|
}
|
||||||
|
mp := newMockStreamProvider(events)
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
text, err := reader.Text()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("text error: %v", err)
|
||||||
|
}
|
||||||
|
if text != "result" {
|
||||||
|
t.Errorf("expected 'result', got %q", text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_EmptyStream(t *testing.T) {
|
||||||
|
// Stream that completes without a done event (no response)
|
||||||
|
mp := newMockStreamProvider([]provider.StreamEvent{
|
||||||
|
{Type: provider.StreamEventText, Text: "hi"},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
_, err = reader.Collect()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for stream without done event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReader_StreamFuncError(t *testing.T) {
|
||||||
|
// Stream function returns error directly
|
||||||
|
mp := &mockProvider{
|
||||||
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{}, nil
|
||||||
|
},
|
||||||
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||||
|
return errors.New("stream init failed")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error creating reader: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// The error should come through as an error event
|
||||||
|
_, err = reader.Collect()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from stream function")
|
||||||
|
}
|
||||||
|
}
|
||||||
193
v2/tool.go
Normal file
193
v2/tool.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tool defines a tool that the LLM can invoke.
|
||||||
|
type Tool struct {
|
||||||
|
// Name is the tool's unique identifier.
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Description tells the LLM what this tool does.
|
||||||
|
Description string
|
||||||
|
|
||||||
|
// Schema is the JSON Schema for the tool's parameters.
|
||||||
|
Schema map[string]any
|
||||||
|
|
||||||
|
// fn holds the implementation function (set via Define or DefineSimple).
|
||||||
|
fn reflect.Value
|
||||||
|
pTyp reflect.Type // nil for parameterless tools
|
||||||
|
|
||||||
|
// isMCP indicates this tool is provided by an MCP server.
|
||||||
|
isMCP bool
|
||||||
|
mcpServer *MCPServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define creates a tool from a typed handler function.
|
||||||
|
// T must be a struct. Struct fields become the tool's parameters.
|
||||||
|
//
|
||||||
|
// Struct tags:
|
||||||
|
// - `json:"name"` — parameter name
|
||||||
|
// - `description:"..."` — parameter description
|
||||||
|
// - `enum:"a,b,c"` — enum constraint
|
||||||
|
//
|
||||||
|
// Pointer fields are optional; non-pointer fields are required.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// type WeatherParams struct {
|
||||||
|
// City string `json:"city" description:"The city to query"`
|
||||||
|
// Unit string `json:"unit" description:"Temperature unit" enum:"celsius,fahrenheit"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// llm.Define[WeatherParams]("get_weather", "Get weather for a city",
|
||||||
|
// func(ctx context.Context, p WeatherParams) (string, error) {
|
||||||
|
// return fmt.Sprintf("72F in %s", p.City), nil
|
||||||
|
// },
|
||||||
|
// )
|
||||||
|
func Define[T any](name, description string, fn func(context.Context, T) (string, error)) Tool {
|
||||||
|
var zero T
|
||||||
|
return Tool{
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
Schema: schema.FromStruct(zero),
|
||||||
|
fn: reflect.ValueOf(fn),
|
||||||
|
pTyp: reflect.TypeOf(zero),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefineSimple creates a parameterless tool.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// llm.DefineSimple("get_time", "Get the current time",
|
||||||
|
// func(ctx context.Context) (string, error) {
|
||||||
|
// return time.Now().Format(time.RFC3339), nil
|
||||||
|
// },
|
||||||
|
// )
|
||||||
|
func DefineSimple(name, description string, fn func(context.Context) (string, error)) Tool {
|
||||||
|
return Tool{
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
Schema: map[string]any{"type": "object", "properties": map[string]any{}},
|
||||||
|
fn: reflect.ValueOf(fn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs the tool with the given JSON arguments string.
|
||||||
|
func (t Tool) Execute(ctx context.Context, argsJSON string) (string, error) {
|
||||||
|
if t.isMCP {
|
||||||
|
var args map[string]any
|
||||||
|
if argsJSON != "" && argsJSON != "{}" {
|
||||||
|
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid MCP tool arguments: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return t.mcpServer.CallTool(ctx, t.Name, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameterless tool
|
||||||
|
if t.pTyp == nil {
|
||||||
|
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx)})
|
||||||
|
if !out[1].IsNil() {
|
||||||
|
return "", out[1].Interface().(error)
|
||||||
|
}
|
||||||
|
return out[0].String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Typed tool: unmarshal JSON into the struct, call the function
|
||||||
|
p := reflect.New(t.pTyp)
|
||||||
|
if argsJSON != "" && argsJSON != "{}" {
|
||||||
|
if err := json.Unmarshal([]byte(argsJSON), p.Interface()); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid tool arguments: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out := t.fn.Call([]reflect.Value{reflect.ValueOf(ctx), p.Elem()})
|
||||||
|
if !out[1].IsNil() {
|
||||||
|
return "", out[1].Interface().(error)
|
||||||
|
}
|
||||||
|
return out[0].String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolBox is a collection of tools available for use by an LLM.
|
||||||
|
type ToolBox struct {
|
||||||
|
tools map[string]Tool
|
||||||
|
mcpServers []*MCPServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToolBox creates a new ToolBox from the given tools.
|
||||||
|
func NewToolBox(tools ...Tool) *ToolBox {
|
||||||
|
tb := &ToolBox{tools: make(map[string]Tool)}
|
||||||
|
for _, t := range tools {
|
||||||
|
tb.tools[t.Name] = t
|
||||||
|
}
|
||||||
|
return tb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds tools to the toolbox and returns it for chaining.
|
||||||
|
func (tb *ToolBox) Add(tools ...Tool) *ToolBox {
|
||||||
|
if tb.tools == nil {
|
||||||
|
tb.tools = make(map[string]Tool)
|
||||||
|
}
|
||||||
|
for _, t := range tools {
|
||||||
|
tb.tools[t.Name] = t
|
||||||
|
}
|
||||||
|
return tb
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMCP adds an MCP server's tools to the toolbox. The server must be connected.
|
||||||
|
func (tb *ToolBox) AddMCP(server *MCPServer) *ToolBox {
|
||||||
|
if tb.tools == nil {
|
||||||
|
tb.tools = make(map[string]Tool)
|
||||||
|
}
|
||||||
|
tb.mcpServers = append(tb.mcpServers, server)
|
||||||
|
|
||||||
|
for _, tool := range server.ListTools() {
|
||||||
|
tb.tools[tool.Name] = tool
|
||||||
|
}
|
||||||
|
return tb
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllTools returns all tools (local + MCP) as a slice.
|
||||||
|
func (tb *ToolBox) AllTools() []Tool {
|
||||||
|
if tb == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tools := make([]Tool, 0, len(tb.tools))
|
||||||
|
for _, t := range tb.tools {
|
||||||
|
tools = append(tools, t)
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute executes a tool call by name.
|
||||||
|
func (tb *ToolBox) Execute(ctx context.Context, call ToolCall) (string, error) {
|
||||||
|
if tb == nil {
|
||||||
|
return "", ErrNoToolsConfigured
|
||||||
|
}
|
||||||
|
tool, ok := tb.tools[call.Name]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("%w: %s", ErrToolNotFound, call.Name)
|
||||||
|
}
|
||||||
|
return tool.Execute(ctx, call.Arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteAll executes all tool calls and returns tool result messages.
|
||||||
|
func (tb *ToolBox) ExecuteAll(ctx context.Context, calls []ToolCall) ([]Message, error) {
|
||||||
|
var results []Message
|
||||||
|
for _, call := range calls {
|
||||||
|
result, err := tb.Execute(ctx, call)
|
||||||
|
text := result
|
||||||
|
if err != nil {
|
||||||
|
text = "Error: " + err.Error()
|
||||||
|
}
|
||||||
|
results = append(results, ToolResultMessage(call.ID, text))
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
139
v2/tool_test.go
Normal file
139
v2/tool_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type calcParams struct {
|
||||||
|
A float64 `json:"a" description:"First number"`
|
||||||
|
B float64 `json:"b" description:"Second number"`
|
||||||
|
Op string `json:"op" description:"Operation" enum:"add,sub,mul,div"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefine(t *testing.T) {
|
||||||
|
tool := Define[calcParams]("calc", "Calculator",
|
||||||
|
func(ctx context.Context, p calcParams) (string, error) {
|
||||||
|
var result float64
|
||||||
|
switch p.Op {
|
||||||
|
case "add":
|
||||||
|
result = p.A + p.B
|
||||||
|
case "sub":
|
||||||
|
result = p.A - p.B
|
||||||
|
case "mul":
|
||||||
|
result = p.A * p.B
|
||||||
|
case "div":
|
||||||
|
result = p.A / p.B
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(result)
|
||||||
|
return string(b), err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool.Name != "calc" {
|
||||||
|
t.Errorf("expected name 'calc', got %q", tool.Name)
|
||||||
|
}
|
||||||
|
if tool.Description != "Calculator" {
|
||||||
|
t.Errorf("expected description 'Calculator', got %q", tool.Description)
|
||||||
|
}
|
||||||
|
if tool.Schema["type"] != "object" {
|
||||||
|
t.Errorf("expected schema type=object, got %v", tool.Schema["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test execution
|
||||||
|
result, err := tool.Execute(context.Background(), `{"a": 10, "b": 3, "op": "add"}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if result != "13" {
|
||||||
|
t.Errorf("expected '13', got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefineSimple(t *testing.T) {
|
||||||
|
tool := DefineSimple("hello", "Say hello",
|
||||||
|
func(ctx context.Context) (string, error) {
|
||||||
|
return "Hello, world!", nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := tool.Execute(context.Background(), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if result != "Hello, world!" {
|
||||||
|
t.Errorf("expected 'Hello, world!', got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolBox(t *testing.T) {
|
||||||
|
tool1 := DefineSimple("tool1", "Tool 1", func(ctx context.Context) (string, error) {
|
||||||
|
return "result1", nil
|
||||||
|
})
|
||||||
|
tool2 := DefineSimple("tool2", "Tool 2", func(ctx context.Context) (string, error) {
|
||||||
|
return "result2", nil
|
||||||
|
})
|
||||||
|
|
||||||
|
tb := NewToolBox(tool1, tool2)
|
||||||
|
|
||||||
|
tools := tb.AllTools()
|
||||||
|
if len(tools) != 2 {
|
||||||
|
t.Errorf("expected 2 tools, got %d", len(tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tb.Execute(context.Background(), ToolCall{ID: "1", Name: "tool1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if result != "result1" {
|
||||||
|
t.Errorf("expected 'result1', got %q", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test not found
|
||||||
|
_, err = tb.Execute(context.Background(), ToolCall{ID: "x", Name: "nonexistent"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolBoxExecuteAll(t *testing.T) {
|
||||||
|
tb := NewToolBox(
|
||||||
|
DefineSimple("t1", "T1", func(ctx context.Context) (string, error) {
|
||||||
|
return "r1", nil
|
||||||
|
}),
|
||||||
|
DefineSimple("t2", "T2", func(ctx context.Context) (string, error) {
|
||||||
|
return "r2", nil
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
calls := []ToolCall{
|
||||||
|
{ID: "c1", Name: "t1"},
|
||||||
|
{ID: "c2", Name: "t2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := tb.ExecuteAll(context.Background(), calls)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute all failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgs[0].Role != RoleTool {
|
||||||
|
t.Errorf("expected role=tool, got %v", msgs[0].Role)
|
||||||
|
}
|
||||||
|
if msgs[0].ToolCallID != "c1" {
|
||||||
|
t.Errorf("expected toolCallID=c1, got %v", msgs[0].ToolCallID)
|
||||||
|
}
|
||||||
|
if msgs[0].Content.Text != "r1" {
|
||||||
|
t.Errorf("expected content=r1, got %v", msgs[0].Content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// jsonMarshal helper for calcParams test
|
||||||
|
func (p calcParams) jsonMarshal(result float64) (string, error) {
|
||||||
|
b, err := json.Marshal(result)
|
||||||
|
return string(b), err
|
||||||
|
}
|
||||||
59
v2/tools/browser.go
Normal file
59
v2/tools/browser.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BrowserParams defines parameters for the browser tool.
|
||||||
|
type BrowserParams struct {
|
||||||
|
URL string `json:"url" description:"The URL to fetch and extract text from"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Browser creates a simple web content fetcher tool.
|
||||||
|
// It fetches a URL and returns the text content.
|
||||||
|
//
|
||||||
|
// For a full headless browser, consider using an MCP server like Playwright MCP.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tools := llm.NewToolBox(tools.Browser())
|
||||||
|
func Browser() llm.Tool {
|
||||||
|
return llm.Define[BrowserParams]("browser", "Fetch a web page and return its text content",
|
||||||
|
func(ctx context.Context, p BrowserParams) (string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "go-llm/2.0 (Web Fetcher)")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("fetching URL: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Limit to 1MB
|
||||||
|
limited := io.LimitReader(resp.Body, 1<<20)
|
||||||
|
body, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := map[string]any{
|
||||||
|
"url": p.URL,
|
||||||
|
"status": resp.StatusCode,
|
||||||
|
"content_type": resp.Header.Get("Content-Type"),
|
||||||
|
"body": string(body),
|
||||||
|
}
|
||||||
|
|
||||||
|
out, _ := json.MarshalIndent(result, "", " ")
|
||||||
|
return string(out), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
101
v2/tools/exec.go
Normal file
101
v2/tools/exec.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecParams defines parameters for the exec tool.
|
||||||
|
type ExecParams struct {
|
||||||
|
Command string `json:"command" description:"The shell command to execute"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecOption configures the exec tool.
|
||||||
|
type ExecOption func(*execConfig)
|
||||||
|
|
||||||
|
type execConfig struct {
|
||||||
|
allowedCommands []string
|
||||||
|
workDir string
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAllowedCommands restricts which commands can be executed.
|
||||||
|
// If empty, all commands are allowed.
|
||||||
|
func WithAllowedCommands(cmds []string) ExecOption {
|
||||||
|
return func(c *execConfig) { c.allowedCommands = cmds }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWorkDir sets the working directory for command execution.
|
||||||
|
func WithWorkDir(dir string) ExecOption {
|
||||||
|
return func(c *execConfig) { c.workDir = dir }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExecTimeout sets the maximum execution time.
|
||||||
|
func WithExecTimeout(d time.Duration) ExecOption {
|
||||||
|
return func(c *execConfig) { c.timeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec creates a shell command execution tool.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tools := llm.NewToolBox(
|
||||||
|
// tools.Exec(tools.WithAllowedCommands([]string{"ls", "cat", "grep"})),
|
||||||
|
// )
|
||||||
|
func Exec(opts ...ExecOption) llm.Tool {
|
||||||
|
cfg := &execConfig{
|
||||||
|
timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return llm.Define[ExecParams]("exec", "Execute a shell command and return its output",
|
||||||
|
func(ctx context.Context, p ExecParams) (string, error) {
|
||||||
|
// Check allowed commands
|
||||||
|
if len(cfg.allowedCommands) > 0 {
|
||||||
|
parts := strings.Fields(p.Command)
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return "", fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
allowed := false
|
||||||
|
for _, cmd := range cfg.allowedCommands {
|
||||||
|
if parts[0] == cmd {
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
return "", fmt.Errorf("command %q is not in the allowed list", parts[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, cfg.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cmd = exec.CommandContext(ctx, "cmd", "/C", p.Command)
|
||||||
|
} else {
|
||||||
|
cmd = exec.CommandContext(ctx, "sh", "-c", p.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.workDir != "" {
|
||||||
|
cmd.Dir = cfg.workDir
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Error: %s\nOutput: %s", err.Error(), string(output)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(output), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
75
v2/tools/http.go
Normal file
75
v2/tools/http.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPParams defines parameters for the HTTP request tool.
|
||||||
|
type HTTPParams struct {
|
||||||
|
Method string `json:"method" description:"HTTP method" enum:"GET,POST,PUT,DELETE,PATCH,HEAD"`
|
||||||
|
URL string `json:"url" description:"Request URL"`
|
||||||
|
Headers map[string]string `json:"headers,omitempty" description:"Request headers"`
|
||||||
|
Body *string `json:"body,omitempty" description:"Request body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP creates an HTTP request tool.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tools := llm.NewToolBox(tools.HTTP())
|
||||||
|
func HTTP() llm.Tool {
|
||||||
|
return llm.Define[HTTPParams]("http_request", "Make an HTTP request and return the response",
|
||||||
|
func(ctx context.Context, p HTTPParams) (string, error) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if p.Body != nil {
|
||||||
|
bodyReader = bytes.NewBufferString(*p.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, p.Method, p.URL, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range p.Headers {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Limit to 1MB
|
||||||
|
limited := io.LimitReader(resp.Body, 1<<20)
|
||||||
|
body, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := map[string]string{}
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := map[string]any{
|
||||||
|
"status": resp.StatusCode,
|
||||||
|
"status_text": resp.Status,
|
||||||
|
"headers": headers,
|
||||||
|
"body": string(body),
|
||||||
|
}
|
||||||
|
|
||||||
|
out, _ := json.MarshalIndent(result, "", " ")
|
||||||
|
return string(out), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
81
v2/tools/readfile.go
Normal file
81
v2/tools/readfile.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadFileParams defines parameters for the read file tool.
|
||||||
|
type ReadFileParams struct {
|
||||||
|
Path string `json:"path" description:"File path to read"`
|
||||||
|
Start *int `json:"start,omitempty" description:"Starting line number (1-based, inclusive)"`
|
||||||
|
End *int `json:"end,omitempty" description:"Ending line number (1-based, inclusive)"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFile creates a file reading tool.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tools := llm.NewToolBox(tools.ReadFile())
|
||||||
|
func ReadFile() llm.Tool {
|
||||||
|
return llm.Define[ReadFileParams]("read_file", "Read the contents of a file",
|
||||||
|
func(ctx context.Context, p ReadFileParams) (string, error) {
|
||||||
|
f, err := os.Open(p.Path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("opening file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// If no line range specified, read the whole file (limited to 1MB)
|
||||||
|
if p.Start == nil && p.End == nil {
|
||||||
|
info, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("stat file: %w", err)
|
||||||
|
}
|
||||||
|
if info.Size() > 1<<20 {
|
||||||
|
return "", fmt.Errorf("file too large (%d bytes), use start/end to read a range", info.Size())
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(p.Path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading file: %w", err)
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read specific line range
|
||||||
|
start := 1
|
||||||
|
end := -1
|
||||||
|
if p.Start != nil {
|
||||||
|
start = *p.Start
|
||||||
|
}
|
||||||
|
if p.End != nil {
|
||||||
|
end = *p.End
|
||||||
|
}
|
||||||
|
|
||||||
|
var lines []string
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
lineNum := 0
|
||||||
|
for scanner.Scan() {
|
||||||
|
lineNum++
|
||||||
|
if lineNum < start {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if end > 0 && lineNum > end {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf("%d: %s", lineNum, scanner.Text()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return "", fmt.Errorf("scanning file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n"), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
101
v2/tools/websearch.go
Normal file
101
v2/tools/websearch.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
// Package tools provides ready-to-use tool implementations for common agent patterns.
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebSearchParams defines parameters for the web search tool.
|
||||||
|
type WebSearchParams struct {
|
||||||
|
Query string `json:"query" description:"The search query"`
|
||||||
|
Count *int `json:"count,omitempty" description:"Number of results to return (default 5, max 20)"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSearch creates a web search tool using the Brave Search API.
|
||||||
|
//
|
||||||
|
// Get a free API key at https://brave.com/search/api/
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tools := llm.NewToolBox(tools.WebSearch("your-brave-api-key"))
|
||||||
|
func WebSearch(apiKey string) llm.Tool {
|
||||||
|
return llm.Define[WebSearchParams]("web_search", "Search the web for information using Brave Search",
|
||||||
|
func(ctx context.Context, p WebSearchParams) (string, error) {
|
||||||
|
count := 5
|
||||||
|
if p.Count != nil && *p.Count > 0 {
|
||||||
|
count = *p.Count
|
||||||
|
if count > 20 {
|
||||||
|
count = 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
u := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
|
||||||
|
url.QueryEscape(p.Query), count)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("X-Subscription-Token", apiKey)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("search request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("search API returned %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and simplify the response
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(body, &raw); err != nil {
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []result
|
||||||
|
if web, ok := raw["web"].(map[string]any); ok {
|
||||||
|
if items, ok := web["results"].([]any); ok {
|
||||||
|
for _, item := range items {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
r := result{}
|
||||||
|
if t, ok := m["title"].(string); ok {
|
||||||
|
r.Title = t
|
||||||
|
}
|
||||||
|
if u, ok := m["url"].(string); ok {
|
||||||
|
r.URL = u
|
||||||
|
}
|
||||||
|
if d, ok := m["description"].(string); ok {
|
||||||
|
r.Description = d
|
||||||
|
}
|
||||||
|
results = append(results, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, _ := json.MarshalIndent(results, "", " ")
|
||||||
|
return string(out), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
31
v2/tools/writefile.go
Normal file
31
v2/tools/writefile.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
llm "gitea.stevedudenhoeffer.com/steve/go-llm/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteFileParams defines parameters for the write file tool.
|
||||||
|
type WriteFileParams struct {
|
||||||
|
Path string `json:"path" description:"File path to write"`
|
||||||
|
Content string `json:"content" description:"Content to write to the file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFile creates a file writing tool.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tools := llm.NewToolBox(tools.WriteFile())
|
||||||
|
func WriteFile() llm.Tool {
|
||||||
|
return llm.Define[WriteFileParams]("write_file", "Write content to a file (creates or overwrites)",
|
||||||
|
func(ctx context.Context, p WriteFileParams) (string, error) {
|
||||||
|
if err := os.WriteFile(p.Path, []byte(p.Content), 0644); err != nil {
|
||||||
|
return "", fmt.Errorf("writing file: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Successfully wrote %d bytes to %s", len(p.Content), p.Path), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
100
v2/transcriber.go
Normal file
100
v2/transcriber.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transcriber abstracts a speech-to-text model implementation.
|
||||||
|
type Transcriber = provider.Transcriber
|
||||||
|
|
||||||
|
// TranscriptionResponseFormat controls the output format requested from a transcriber.
|
||||||
|
type TranscriptionResponseFormat = provider.TranscriptionResponseFormat
|
||||||
|
|
||||||
|
const (
|
||||||
|
TranscriptionResponseFormatJSON = provider.TranscriptionResponseFormatJSON
|
||||||
|
TranscriptionResponseFormatVerboseJSON = provider.TranscriptionResponseFormatVerboseJSON
|
||||||
|
TranscriptionResponseFormatText = provider.TranscriptionResponseFormatText
|
||||||
|
TranscriptionResponseFormatSRT = provider.TranscriptionResponseFormatSRT
|
||||||
|
TranscriptionResponseFormatVTT = provider.TranscriptionResponseFormatVTT
|
||||||
|
)
|
||||||
|
|
||||||
|
// TranscriptionTimestampGranularity defines the requested timestamp detail.
|
||||||
|
type TranscriptionTimestampGranularity = provider.TranscriptionTimestampGranularity
|
||||||
|
|
||||||
|
const (
|
||||||
|
TranscriptionTimestampGranularityWord = provider.TranscriptionTimestampGranularityWord
|
||||||
|
TranscriptionTimestampGranularitySegment = provider.TranscriptionTimestampGranularitySegment
|
||||||
|
)
|
||||||
|
|
||||||
|
// TranscriptionOptions configures transcription behavior.
|
||||||
|
type TranscriptionOptions = provider.TranscriptionOptions
|
||||||
|
|
||||||
|
// Transcription captures a normalized transcription result.
|
||||||
|
type Transcription = provider.Transcription
|
||||||
|
|
||||||
|
// TranscriptionSegment provides a coarse time-sliced transcription segment.
|
||||||
|
type TranscriptionSegment = provider.TranscriptionSegment
|
||||||
|
|
||||||
|
// TranscriptionWord provides a word-level timestamp.
|
||||||
|
type TranscriptionWord = provider.TranscriptionWord
|
||||||
|
|
||||||
|
// TranscriptionTokenLogprob captures token-level log probability details.
|
||||||
|
type TranscriptionTokenLogprob = provider.TranscriptionTokenLogprob
|
||||||
|
|
||||||
|
// TranscriptionUsage captures token or duration usage details.
|
||||||
|
type TranscriptionUsage = provider.TranscriptionUsage
|
||||||
|
|
||||||
|
// TranscribeFile converts an audio file to WAV (via ffmpeg) and transcribes it.
|
||||||
|
func TranscribeFile(ctx context.Context, filename string, transcriber Transcriber, opts TranscriptionOptions) (Transcription, error) {
|
||||||
|
if transcriber == nil {
|
||||||
|
return Transcription{}, fmt.Errorf("transcriber is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
wav, err := audioFileToWav(ctx, filename)
|
||||||
|
if err != nil {
|
||||||
|
return Transcription{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transcriber.Transcribe(ctx, wav, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func audioFileToWav(ctx context.Context, filename string) ([]byte, error) {
|
||||||
|
if filename == "" {
|
||||||
|
return nil, fmt.Errorf("filename is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(filepath.Ext(filename), ".wav") {
|
||||||
|
data, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read wav file: %w", err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tempFile, err := os.CreateTemp("", "go-llm-audio-*.wav")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create temp wav file: %w", err)
|
||||||
|
}
|
||||||
|
tempPath := tempFile.Name()
|
||||||
|
_ = tempFile.Close()
|
||||||
|
defer os.Remove(tempPath)
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", filename, "-vn", "-f", "wav", tempPath)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return nil, fmt.Errorf("ffmpeg convert failed: %w (output: %s)", err, strings.TrimSpace(string(output)))
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tempPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read converted wav file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user