Files
steve 67c3ebe067
CI / Build, Test & Lint (push) Successful in 10m50s
feat(ollama): add automatic retry with exponential backoff for transient HTTP errors
Ollama Cloud returns HTTP 503 when the model is temporarily overloaded,
429 on rate limit, and 502 on upstream failures. These are transient
conditions that resolve on retry. Previously they bubbled up as hard
errors, forcing callers to implement their own retry logic.

The retry is implemented at the HTTP transport level in doChatRequest,
so both Complete and Stream benefit transparently. Strategy: up to 3
retries with exponential backoff (1s, 2s, 4s); a Retry-After header is
honored on any retried response (most commonly 429), and context
cancellation is checked between retries.
Non-transient errors (400, 401, 403, 404, 500) are never retried.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 11:58:25 -04:00

618 lines
18 KiB
Go

// Package ollama implements the go-llm v2 provider interface for Ollama,
// targeting Ollama's native /api/chat endpoint. Supports both local Ollama
// instances (no API key) and Ollama Cloud (https://ollama.com, requires an
// API key).
package ollama
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
const (
	// DefaultLocalBaseURL is the base URL of an Ollama server running on the
	// local machine with its default port.
	DefaultLocalBaseURL = "http://localhost:11434"

	// DefaultCloudBaseURL is the base URL of the hosted Ollama Cloud service.
	DefaultCloudBaseURL = "https://ollama.com"
)

// Retry tuning for transient HTTP failures (502, 503, 429).
const (
	// retryMaxAttempts caps how many times a transient failure is retried, so
	// a request is sent at most 1 + retryMaxAttempts times.
	retryMaxAttempts = 3

	// retryBaseDelay is the first backoff delay; each subsequent retry doubles
	// it, giving 1s, 2s, 4s.
	retryBaseDelay = time.Second
)
// isTransientHTTPStatus reports whether an HTTP status code signals a
// temporary condition worth retrying: 502 Bad Gateway, 503 Service
// Unavailable, or 429 Too Many Requests. Every other code, including 500,
// is treated as final.
func isTransientHTTPStatus(code int) bool {
	switch code {
	case http.StatusBadGateway, // 502
		http.StatusServiceUnavailable, // 503
		http.StatusTooManyRequests: // 429
		return true
	default:
		return false
	}
}
// Provider implements provider.Provider on top of Ollama's native /api/chat
// endpoint.
//
// Authentication depends on apiKey. When it is empty the provider talks to a
// local instance and sends no Authorization header. Otherwise the key is sent
// as "Authorization: Bearer <apiKey>", which is the Ollama Cloud mode.
type Provider struct {
	apiKey, baseURL string // credential (may be empty) and server root URL
	client          *http.Client

	// retryBaseDelayOverride, when positive, replaces retryBaseDelay as the
	// backoff base. It exists so tests can avoid real sleeps; production code
	// leaves it at zero.
	retryBaseDelayOverride time.Duration
}
// newNative builds a native Ollama provider for the given API key and base
// URL. Pass an empty apiKey for a local, unauthenticated instance.
//
// The HTTP client has no client-level timeout. Request lifetimes are bounded
// by the ctx passed to Complete or Stream instead. Callers should normally use
// the package-level New() constructor or the v2 llm.Ollama() /
// llm.OllamaCloud() helpers rather than calling this directly.
func newNative(apiKey, baseURL string) *Provider {
	p := new(Provider)
	p.apiKey = apiKey
	p.baseURL = baseURL
	p.client = new(http.Client)
	return p
}
// nativeChatRequest is the JSON body POSTed to /api/chat; buildChatRequest
// fills it in from a provider.Request.
type nativeChatRequest struct {
	Model string `json:"model"`
	// Messages has no omitempty, so a nil slice is sent as JSON null.
	Messages []nativeChatMessage `json:"messages"`
	Tools    []nativeToolDef     `json:"tools,omitempty"`
	// Stream is always sent; true asks Ollama for an NDJSON stream.
	Stream bool `json:"stream"`
	// Think is polymorphic — Ollama accepts true/false or "low"/"medium"/"high".
	// A nil value (see encodeThink) omits the field entirely.
	Think json.RawMessage `json:"think,omitempty"`
	// Options holds sampling settings (temperature, top_p, num_predict, stop
	// as set by buildChatRequest); omitted when nil or empty.
	Options map[string]any `json:"options,omitempty"`
}
// nativeChatMessage is one entry in the messages array on the wire. The same
// shape carries assistant tool calls and tool-role responses, and is also the
// "message" object inside each reply (see nativeChatResponse).
type nativeChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content,omitempty"`
	// Images holds base64-encoded image payloads (see imageToBase64).
	Images    []string         `json:"images,omitempty"`
	ToolCalls []nativeToolCall `json:"tool_calls,omitempty"`
	// ToolCallID is copied from provider.Message.ToolCallID by convertMessage.
	// NOTE(review): presumably only meaningful on tool-role messages — confirm
	// that Ollama's native API reads this field.
	ToolCallID string `json:"tool_call_id,omitempty"`
	// Thinking carries the model's reasoning text in replies; convertMessage
	// never sets it on outgoing messages.
	Thinking string `json:"thinking,omitempty"`
}
// nativeToolCall mirrors Ollama's tool-call wire shape: a function with name
// and JSON-encoded arguments. Ollama's spec doesn't require an id, but some
// builds and some streaming chunks include one — we accept it on both wire and
// internal sides.
type nativeToolCall struct {
	ID       string             `json:"id,omitempty"`
	Function nativeFunctionCall `json:"function"`
}

// nativeFunctionCall is the "function" object inside a tool call.
type nativeFunctionCall struct {
	// Index is a pointer so an absent index (nil) can be told apart from
	// index 0; it is used to correlate streamed deltas and to synthesize ids.
	Index *int   `json:"index,omitempty"`
	Name  string `json:"name,omitempty"`
	// Arguments is kept raw because it may arrive either as a JSON object or
	// as a JSON-encoded string (see decodeArgumentDelta and
	// rawMessageToArgString).
	Arguments json.RawMessage `json:"arguments,omitempty"`
}
// nativeChatResponse is the JSON body returned from a non-streaming /api/chat
// call (and is also the per-line shape during streaming).
type nativeChatResponse struct {
	// Model is decoded but not read elsewhere in this file.
	Model   string            `json:"model,omitempty"`
	Message nativeChatMessage `json:"message"`
	// Done marks the final chunk of a stream; Stream stops reading there.
	Done bool `json:"done"`
	// DoneReason is decoded but not read elsewhere in this file.
	DoneReason string `json:"done_reason,omitempty"`
	// PromptEvalCount and EvalCount are surfaced as provider.Usage input and
	// output token counts respectively.
	PromptEvalCount int `json:"prompt_eval_count,omitempty"`
	EvalCount       int `json:"eval_count,omitempty"`
	// TotalDuration is decoded but unused here. NOTE(review): presumably
	// nanoseconds per Ollama's API docs — confirm before relying on it.
	TotalDuration int64 `json:"total_duration,omitempty"`
}
// nativeToolDef is the wire shape of a tool definition sent to Ollama.
// buildChatRequest always sets Type to "function".
type nativeToolDef struct {
	Type     string            `json:"type"`
	Function nativeFunctionDef `json:"function"`
}

// nativeFunctionDef describes one callable tool. Parameters carries the
// tool's schema (provider tool Schema) and is omitted when nil or empty.
type nativeFunctionDef struct {
	Name        string         `json:"name"`
	Description string         `json:"description,omitempty"`
	Parameters  map[string]any `json:"parameters,omitempty"`
}
// encodeThink maps a go-llm Reasoning value onto Ollama's polymorphic
// `think` field:
//   - "" yields nil, so the omitempty field is left out of the request;
//   - "true" and "false" yield the JSON booleans true and false;
//   - anything else (e.g. "low", "medium", "high") yields a JSON string.
func encodeThink(reasoning string) json.RawMessage {
	if reasoning == "" {
		return nil
	}
	if reasoning == "true" || reasoning == "false" {
		// These spellings are already valid JSON booleans.
		return json.RawMessage(reasoning)
	}
	// Marshaling a Go string cannot fail, so the error is safely ignored.
	quoted, _ := json.Marshal(reasoning)
	return quoted
}
// Complete sends req to /api/chat with streaming disabled and converts the
// single JSON reply into a provider.Response.
//
// A non-2xx status becomes an error carrying the status code and response
// body. That status is whatever doChatRequest returns after its transient
// retries. Usage is set only when Ollama reports a non-zero prompt or eval
// token count. Tool calls without a wire id get a synthesized "tc_N" id.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	payload, err := p.buildChatRequest(req, false)
	if err != nil {
		return provider.Response{}, err
	}
	httpResp, err := p.doChatRequest(ctx, payload)
	if err != nil {
		return provider.Response{}, err
	}
	defer httpResp.Body.Close()

	if code := httpResp.StatusCode; code < 200 || code > 299 {
		errBody, _ := io.ReadAll(httpResp.Body)
		return provider.Response{}, fmt.Errorf("ollama: HTTP %d: %s", code, string(errBody))
	}

	var decoded nativeChatResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&decoded); err != nil {
		return provider.Response{}, fmt.Errorf("ollama: decode response: %w", err)
	}

	msg := decoded.Message

	// calls stays nil when the reply has no tool calls.
	var calls []provider.ToolCall
	for i := range msg.ToolCalls {
		tc := msg.ToolCalls[i]
		calls = append(calls, provider.ToolCall{
			ID:        toolCallID(tc, i),
			Name:      tc.Function.Name,
			Arguments: rawMessageToArgString(tc.Function.Arguments),
		})
	}

	var usage *provider.Usage
	if in, out := decoded.PromptEvalCount, decoded.EvalCount; in > 0 || out > 0 {
		usage = &provider.Usage{
			InputTokens:  in,
			OutputTokens: out,
			TotalTokens:  in + out,
		}
	}

	return provider.Response{
		Text:      msg.Content,
		Thinking:  msg.Thinking,
		ToolCalls: calls,
		Usage:     usage,
	}, nil
}
// Stream performs a streaming chat completion via /api/chat with
// `stream: true`, parsing the NDJSON response line by line.
//
// As chunks arrive it emits Thinking and Text events for non-empty message
// deltas, and ToolStart / ToolDelta events for tool calls. Tool-call argument
// deltas are accumulated across chunks, keyed by streamToolKey (wire id, then
// function index, then position within the chunk).
//
// The stream ends on the first chunk with done=true, or at EOF. At that point
// Stream emits one ToolEnd event per accumulated tool call, in first-seen
// order. It then emits a single Done event carrying the aggregated text,
// thinking, tool calls and usage.
//
// events is always closed before Stream returns. Every send also watches ctx,
// so a consumer that stops draining events after cancelling ctx cannot block
// this goroutine forever. In that case Stream returns ctx.Err().
//
// Request-building errors, transport errors and non-2xx responses are
// returned without emitting any event. Chunk-decode and body-read errors are
// emitted as a StreamEventError (best effort) and also returned.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	defer close(events)

	// emit delivers ev unless ctx is cancelled first. Without the ctx case,
	// a consumer that stops draining events would block this goroutine forever.
	emit := func(ev provider.StreamEvent) error {
		select {
		case events <- ev:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	body, err := p.buildChatRequest(req, true)
	if err != nil {
		return err
	}
	httpResp, err := p.doChatRequest(ctx, body)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		b, _ := io.ReadAll(httpResp.Body)
		return fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(b))
	}

	scanner := bufio.NewScanner(httpResp.Body)
	// Ollama can emit multi-KB lines on tool-call deltas, so raise the
	// scanner's line limit from bufio's 64 KiB default to 4 MiB.
	const maxLineSize = 4 * 1024 * 1024
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// toolAcc accumulates one tool call across chunks. It is only used through
	// a pointer because strings.Builder must not be copied after first use.
	type toolAcc struct {
		id    string
		name  string
		args  strings.Builder
		index int // ToolIndex emitted on stream events
	}
	tools := map[string]*toolAcc{}
	var toolOrder []*toolAcc // first-seen order; drives ToolEnd emission
	var (
		fullText     strings.Builder
		fullThinking strings.Builder
		usage        *provider.Usage
		streamErr    error
	)

	for scanner.Scan() {
		line := scanner.Bytes()
		if len(bytes.TrimSpace(line)) == 0 {
			continue
		}
		var chunk nativeChatResponse
		if err := json.Unmarshal(line, &chunk); err != nil {
			streamErr = fmt.Errorf("ollama: decode stream chunk: %w", err)
			break
		}

		if chunk.Message.Thinking != "" {
			fullThinking.WriteString(chunk.Message.Thinking)
			if err := emit(provider.StreamEvent{
				Type: provider.StreamEventThinking,
				Text: chunk.Message.Thinking,
			}); err != nil {
				return err
			}
		}
		if chunk.Message.Content != "" {
			fullText.WriteString(chunk.Message.Content)
			if err := emit(provider.StreamEvent{
				Type: provider.StreamEventText,
				Text: chunk.Message.Content,
			}); err != nil {
				return err
			}
		}

		for pos, tc := range chunk.Message.ToolCalls {
			// NOTE(review): tool calls with neither an id nor an index are keyed
			// by their position in this chunk. Two separate chunks that each
			// carry a different tool at position 0 would therefore be merged into
			// one accumulator. Confirm Ollama always sends an id or index when it
			// spreads multiple tool calls across chunks.
			key := streamToolKey(tc, pos)
			acc, exists := tools[key]
			if !exists {
				acc = &toolAcc{
					id:    tc.ID,
					name:  tc.Function.Name,
					index: len(toolOrder),
				}
				if acc.id == "" {
					acc.id = fmt.Sprintf("tc_%d", acc.index)
				}
				tools[key] = acc
				toolOrder = append(toolOrder, acc)
				if err := emit(provider.StreamEvent{
					Type:      provider.StreamEventToolStart,
					ToolIndex: acc.index,
					ToolCall: &provider.ToolCall{
						ID:   acc.id,
						Name: acc.name,
					},
				}); err != nil {
					return err
				}
			} else if tc.Function.Name != "" && acc.name == "" {
				// A continuation chunk may carry the tool's name late; capture it.
				acc.name = tc.Function.Name
			}

			if delta := decodeArgumentDelta(tc.Function.Arguments); delta != "" {
				acc.args.WriteString(delta)
				if err := emit(provider.StreamEvent{
					Type:      provider.StreamEventToolDelta,
					ToolIndex: acc.index,
					ToolCall: &provider.ToolCall{
						Arguments: delta,
					},
				}); err != nil {
					return err
				}
			}
		}

		if chunk.Done {
			if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 {
				usage = &provider.Usage{
					InputTokens:  chunk.PromptEvalCount,
					OutputTokens: chunk.EvalCount,
					TotalTokens:  chunk.PromptEvalCount + chunk.EvalCount,
				}
			}
			break
		}
	}
	if err := scanner.Err(); err != nil && streamErr == nil {
		streamErr = fmt.Errorf("ollama: stream read: %w", err)
	}
	if streamErr != nil {
		// Best effort: the consumer may already be gone (ctx cancelled); the
		// error is returned either way.
		_ = emit(provider.StreamEvent{
			Type:  provider.StreamEventError,
			Error: streamErr,
		})
		return streamErr
	}

	// Finalize accumulated tool calls. A call that received no argument
	// deltas is reported with "{}" arguments.
	finalCalls := make([]provider.ToolCall, 0, len(toolOrder))
	for _, acc := range toolOrder {
		args := acc.args.String()
		if args == "" {
			args = "{}"
		}
		final := provider.ToolCall{
			ID:        acc.id,
			Name:      acc.name,
			Arguments: args,
		}
		finalCalls = append(finalCalls, final)
		if err := emit(provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: acc.index,
			ToolCall:  &final,
		}); err != nil {
			return err
		}
	}
	return emit(provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			Thinking:  fullThinking.String(),
			ToolCalls: finalCalls,
			Usage:     usage,
		},
	})
}
// streamToolKey builds the map key used to correlate one tool call's deltas
// across stream chunks. The most specific identifier available wins:
//   - the wire id, as "id:<id>";
//   - the function index, as "idx:<n>";
//   - the call's position in this chunk's tool_calls array, as "pos:<n>".
//
// A single-tool stream collapses to one key under any of these strategies.
func streamToolKey(tc nativeToolCall, position int) string {
	switch {
	case tc.ID != "":
		return "id:" + tc.ID
	case tc.Function.Index != nil:
		return fmt.Sprintf("idx:%d", *tc.Function.Index)
	default:
		return fmt.Sprintf("pos:%d", position)
	}
}
// decodeArgumentDelta extracts the text to append to a tool call's argument
// buffer from one streamed chunk.
//
// Ollama may deliver arguments in two forms:
//   - as a JSON-encoded string fragment, concatenated chunk by chunk in the
//     openaicompat style — these are unquoted;
//   - as a complete object or array value in one shot — these pass through
//     verbatim, trimmed of surrounding whitespace.
//
// Absent, blank, or JSON null arguments contribute nothing (""). A value that
// starts with a quote but is not a valid JSON string is returned as-is.
func decodeArgumentDelta(raw json.RawMessage) string {
	v := bytes.TrimSpace(raw)
	switch {
	case len(v) == 0, string(v) == "null":
		return ""
	case v[0] == '"':
		var unquoted string
		if json.Unmarshal(v, &unquoted) == nil {
			return unquoted
		}
	}
	return string(v)
}
// buildChatRequest marshals req into the JSON body expected by /api/chat.
// stream sets the wire "stream" flag.
//
// Sampling settings map onto Ollama's options object:
//   - Temperature becomes temperature;
//   - TopP becomes top_p;
//   - MaxTokens becomes num_predict;
//   - Stop becomes stop.
//
// The options object is omitted entirely when none are set. An error from
// converting any message (for example a failed image fetch) aborts the build.
func (p *Provider) buildChatRequest(req provider.Request, stream bool) ([]byte, error) {
	// messages and tools stay nil when req has none, keeping the wire output
	// identical to appending onto a zero-valued request.
	var messages []nativeChatMessage
	for _, m := range req.Messages {
		converted, err := convertMessage(m)
		if err != nil {
			return nil, err
		}
		messages = append(messages, converted)
	}

	var tools []nativeToolDef
	for _, t := range req.Tools {
		def := nativeFunctionDef{
			Name:        t.Name,
			Description: t.Description,
			Parameters:  t.Schema,
		}
		tools = append(tools, nativeToolDef{Type: "function", Function: def})
	}

	opts := map[string]any{}
	if req.Temperature != nil {
		opts["temperature"] = *req.Temperature
	}
	if req.TopP != nil {
		opts["top_p"] = *req.TopP
	}
	if req.MaxTokens != nil {
		opts["num_predict"] = *req.MaxTokens
	}
	if len(req.Stop) > 0 {
		opts["stop"] = req.Stop
	}
	if len(opts) == 0 {
		opts = nil
	}

	return json.Marshal(nativeChatRequest{
		Model:    req.Model,
		Messages: messages,
		Tools:    tools,
		Stream:   stream,
		Think:    encodeThink(req.Reasoning),
		Options:  opts,
	})
}
// doChatRequest POSTs body to <baseURL>/api/chat and returns the raw HTTP
// response. The caller must close its Body.
//
// A response with a transient status (see isTransientHTTPStatus) is drained,
// closed, and retried after retryBackoff's delay, up to retryMaxAttempts
// retries. Once retries are exhausted, or for any other status, the response
// is returned as-is with a nil error; callers inspect StatusCode themselves.
// Transport errors are returned immediately without retrying. Cancelling ctx
// during a backoff wait returns ctx.Err().
func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) {
	endpoint := strings.TrimRight(p.baseURL, "/") + "/api/chat"

	// send performs a single POST attempt. bytes.NewReader gives every attempt
	// a fresh reader over the same payload.
	send := func() (*http.Response, error) {
		httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
		if err != nil {
			return nil, fmt.Errorf("ollama: build request: %w", err)
		}
		httpReq.Header.Set("Content-Type", "application/json")
		if p.apiKey != "" {
			httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
		}
		resp, err := p.client.Do(httpReq)
		if err != nil {
			return nil, fmt.Errorf("ollama: HTTP request: %w", err)
		}
		return resp, nil
	}

	for attempt := 0; ; attempt++ {
		resp, err := send()
		if err != nil {
			return nil, err
		}
		if attempt >= retryMaxAttempts || !isTransientHTTPStatus(resp.StatusCode) {
			return resp, nil
		}

		// Drain and close the failed response so its connection can be reused.
		failedBody, _ := io.ReadAll(resp.Body)
		resp.Body.Close()

		wait := retryBackoff(attempt, resp.Header, p.retryBaseDelayOverride)
		slog.Info("ollama: retrying after transient HTTP error",
			"status", resp.StatusCode,
			"attempt", attempt+1,
			"max_attempts", retryMaxAttempts,
			"delay", wait,
			"body", truncateBody(failedBody, 200),
		)

		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		}
	}
}
// retryBackoff computes the delay before the next retry attempt (attempt is
// 0 for the first retry).
//
// A Retry-After header on the failed response takes precedence. Per RFC 9110
// it may be a positive number of seconds or an HTTP-date, and both forms are
// honored. A seconds value too large to fit in a time.Duration is treated as
// malformed, because it would otherwise overflow into a negative, instant
// delay. A date that is already in the past is ignored.
//
// Otherwise the delay is base * 2^attempt. base is retryBaseDelay, or
// baseOverride when that is positive (tests use it to avoid real waits). The
// exponent is clamped to [0, 16]: a negative attempt would panic on the
// shift, and the upper bound keeps the product from overflowing.
func retryBackoff(attempt int, header http.Header, baseOverride time.Duration) time.Duration {
	if ra := strings.TrimSpace(header.Get("Retry-After")); ra != "" {
		// Largest whole-second count representable as a time.Duration.
		const maxRetryAfterSecs = int64((1<<63 - 1) / time.Second)
		if secs, err := strconv.ParseInt(ra, 10, 64); err == nil {
			if secs > 0 && secs <= maxRetryAfterSecs {
				return time.Duration(secs) * time.Second
			}
		} else if when, err := http.ParseTime(ra); err == nil {
			if d := time.Until(when); d > 0 {
				return d
			}
		}
	}

	base := retryBaseDelay
	if baseOverride > 0 {
		base = baseOverride
	}

	const maxShift = 16
	shift := attempt
	if shift < 0 {
		shift = 0
	}
	if shift > maxShift {
		shift = maxShift
	}
	return base * (1 << shift)
}
// truncateBody returns b as a string of at most maxLen bytes, appending "..."
// when anything was cut. It keeps logged error-response bodies readable.
//
// The cut point is moved back to a UTF-8 rune boundary, by at most three
// bytes, so a multi-byte character is never split into invalid UTF-8. A
// negative maxLen is treated as zero instead of panicking.
func truncateBody(b []byte, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(b) <= maxLen {
		return string(b)
	}
	cut := maxLen
	// b[cut] is the first dropped byte. If it is a UTF-8 continuation byte
	// (10xxxxxx), its rune started before the cut, so drop that rune's leading
	// bytes too. Runes are at most 4 bytes long, so 3 steps suffice, and the
	// cap stops arbitrary binary input from collapsing to nothing.
	for steps := 0; steps < 3 && cut > 0 && b[cut]&0xC0 == 0x80; steps++ {
		cut--
	}
	return string(b[:cut]) + "..."
}
// convertMessage maps a provider.Message onto the native wire message.
//
// Images are inlined as base64 via imageToBase64, which fetches URL-only
// images. Images whose encoding comes back empty are skipped, and the first
// fetch error aborts the conversion. Each assistant tool call gets
// Function.Index set to its position so streaming peers can correlate deltas.
// Blank arguments are sent as {}.
//
// NOTE(review): non-blank arguments are passed through as raw JSON without
// validation. If a stored Arguments string is not valid JSON, json.Marshal of
// the whole request will fail. Confirm callers always store valid JSON here.
func convertMessage(msg provider.Message) (nativeChatMessage, error) {
	var images []string
	for _, img := range msg.Images {
		encoded, err := imageToBase64(img)
		if err != nil {
			return nativeChatMessage{}, err
		}
		if encoded == "" {
			continue
		}
		images = append(images, encoded)
	}

	var calls []nativeToolCall
	for i, tc := range msg.ToolCalls {
		args := strings.TrimSpace(tc.Arguments)
		if args == "" {
			args = "{}"
		}
		index := i // fresh variable per call so each Index pointer is distinct
		calls = append(calls, nativeToolCall{
			ID: tc.ID,
			Function: nativeFunctionCall{
				Index:     &index,
				Name:      tc.Name,
				Arguments: json.RawMessage(args),
			},
		})
	}

	return nativeChatMessage{
		Role:       msg.Role,
		Content:    msg.Content,
		Images:     images,
		ToolCalls:  calls,
		ToolCallID: msg.ToolCallID,
	}, nil
}
// imageToBase64 returns the base64-encoded payload of an image.
//
// Behavior depends on which fields are set:
//   - inline Base64 data is returned unchanged;
//   - with no Base64 and no URL, the result is "" and the caller skips the
//     image;
//   - otherwise the URL is fetched over HTTP and its body is encoded.
//
// A fetch fails on any non-2xx status. The fetch is also bounded by a
// client-level timeout. This helper does not receive the request ctx, and
// http.Get (http.DefaultClient) has no timeout, so one unresponsive image host
// could otherwise hang the whole chat request indefinitely.
//
// NOTE(review): img.URL is fetched as-is. If image URLs can come from
// untrusted input, this is a server-side request forgery vector; consider an
// allow-list.
func imageToBase64(img provider.Image) (string, error) {
	if img.Base64 != "" {
		return img.Base64, nil
	}
	if img.URL == "" {
		return "", nil
	}

	// Covers connect, headers and body read. The zero Transport falls back to
	// http.DefaultTransport, so connections are still pooled across calls.
	const fetchTimeout = 60 * time.Second
	client := &http.Client{Timeout: fetchTimeout}

	resp, err := client.Get(img.URL)
	if err != nil {
		return "", fmt.Errorf("ollama: fetch image %q: %w", img.URL, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", fmt.Errorf("ollama: fetch image %q: HTTP %d", img.URL, resp.StatusCode)
	}
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("ollama: read image %q: %w", img.URL, err)
	}
	return base64.StdEncoding.EncodeToString(data), nil
}
// rawMessageToArgString converts a JSON-encoded tool-call arguments value
// into the string form the provider package uses for ToolCall.Arguments.
// Object/array values pass through verbatim, trimmed of surrounding
// whitespace. JSON string values (some Ollama builds emit pre-stringified
// arguments) are unwrapped.
//
// Missing arguments are normalized to "{}". That covers an absent or blank
// value, JSON null, and a string that is empty or "null" once unwrapped. This
// matches what Stream reports for a tool call that received no argument
// deltas, so callers never see "" (not valid JSON) or "null" for "no args".
// A value that starts with a quote but is not a valid JSON string is
// returned verbatim.
func rawMessageToArgString(raw json.RawMessage) string {
	const noArgs = "{}"
	trimmed := strings.TrimSpace(string(raw))
	if trimmed == "" || trimmed == "null" {
		return noArgs
	}
	if trimmed[0] == '"' {
		var s string
		if err := json.Unmarshal([]byte(trimmed), &s); err == nil {
			if inner := strings.TrimSpace(s); inner == "" || inner == "null" {
				return noArgs
			}
			return s
		}
		// Not a valid JSON string: fall through and return it verbatim.
	}
	return trimmed
}
// toolCallID picks the identifier reported for a non-streamed tool call.
// Ollama's native API usually omits ids, so in order of preference it uses:
//   - the wire id, verbatim;
//   - "tc_<function index>";
//   - "tc_<index>", where index is the call's position in the reply.
func toolCallID(tc nativeToolCall, index int) string {
	if tc.ID != "" {
		return tc.ID
	}
	n := index
	if idx := tc.Function.Index; idx != nil {
		n = *idx
	}
	return fmt.Sprintf("tc_%d", n)
}