feat(ollama): add automatic retry with exponential backoff for transient HTTP errors
CI / Build, Test & Lint (push) Successful in 10m50s

Ollama Cloud returns HTTP 503 when the model is temporarily overloaded,
429 on rate limit, and 502 on upstream failures. These are transient
conditions that resolve on retry. Previously they bubbled up as hard
errors, forcing callers to implement their own retry logic.

The retry is implemented at the HTTP transport level in doChatRequest,
so both Complete and Stream benefit transparently. Strategy: up to 3
retries with exponential backoff (1s, 2s, 4s), Retry-After header
respected for 429, context cancellation checked between retries.
Non-transient errors (400, 401, 403, 404, 500) are never retried.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-25 11:55:43 -04:00
parent 6bac4cb3ed
commit 67c3ebe067
2 changed files with 407 additions and 11 deletions
+83 -1
View File
@@ -12,8 +12,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
) )
@@ -25,6 +28,22 @@ const DefaultLocalBaseURL = "http://localhost:11434"
// DefaultCloudBaseURL is the default base URL for Ollama Cloud. // DefaultCloudBaseURL is the default base URL for Ollama Cloud.
const DefaultCloudBaseURL = "https://ollama.com" const DefaultCloudBaseURL = "https://ollama.com"
// retryMaxAttempts is the maximum number of retry attempts for transient HTTP
// errors (503, 429, 502). Total attempts = 1 initial + retryMaxAttempts.
const retryMaxAttempts = 3
// retryBaseDelay is the base delay for exponential backoff between retries.
// Actual delays: 1s, 2s, 4s (base * 2^attempt).
const retryBaseDelay = 1 * time.Second
// isTransientHTTPStatus reports whether the HTTP status code indicates a
// transient server-side condition that may resolve on retry.
func isTransientHTTPStatus(code int) bool {
return code == http.StatusBadGateway || // 502
code == http.StatusServiceUnavailable || // 503
code == http.StatusTooManyRequests // 429
}
// Provider implements provider.Provider over Ollama's native /api/chat // Provider implements provider.Provider over Ollama's native /api/chat
// endpoint. An empty apiKey means local-mode (no Authorization header sent); // endpoint. An empty apiKey means local-mode (no Authorization header sent);
// a non-empty apiKey is sent as a Bearer token (cloud-mode). // a non-empty apiKey is sent as a Bearer token (cloud-mode).
@@ -32,6 +51,10 @@ type Provider struct {
apiKey string apiKey string
baseURL string baseURL string
client *http.Client client *http.Client
// retryBaseDelayOverride, when non-zero, replaces retryBaseDelay for
// testing. Production code leaves this at the zero value.
retryBaseDelayOverride time.Duration
} }
// newNative constructs a native Ollama provider. Callers should use the // newNative constructs a native Ollama provider. Callers should use the
@@ -420,9 +443,13 @@ func (p *Provider) buildChatRequest(req provider.Request, stream bool) ([]byte,
} }
// doChatRequest POSTs the wire body to /api/chat and returns the raw HTTP // doChatRequest POSTs the wire body to /api/chat and returns the raw HTTP
// response. The caller is responsible for closing the response body. // response. Transient HTTP errors (502, 503, 429) are retried with exponential
// backoff up to retryMaxAttempts times. The caller is responsible for closing
// the response body.
func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) { func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) {
url := strings.TrimRight(p.baseURL, "/") + "/api/chat" url := strings.TrimRight(p.baseURL, "/") + "/api/chat"
for attempt := 0; ; attempt++ {
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, fmt.Errorf("ollama: build request: %w", err) return nil, fmt.Errorf("ollama: build request: %w", err)
@@ -431,13 +458,68 @@ func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Respon
if p.apiKey != "" { if p.apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
} }
resp, err := p.client.Do(httpReq) resp, err := p.client.Do(httpReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("ollama: HTTP request: %w", err) return nil, fmt.Errorf("ollama: HTTP request: %w", err)
} }
// On success or non-transient error, return immediately.
if !isTransientHTTPStatus(resp.StatusCode) || attempt >= retryMaxAttempts {
return resp, nil return resp, nil
} }
// Transient error — drain and close the body before retrying.
respBody, _ := io.ReadAll(resp.Body)
resp.Body.Close()
delay := retryBackoff(attempt, resp.Header, p.retryBaseDelayOverride)
slog.Info("ollama: retrying after transient HTTP error",
"status", resp.StatusCode,
"attempt", attempt+1,
"max_attempts", retryMaxAttempts,
"delay", delay,
"body", truncateBody(respBody, 200),
)
// Wait for backoff or context cancellation.
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
}
}
// retryBackoff computes the delay before the next retry attempt. It uses
// exponential backoff (base * 2^attempt), but respects the Retry-After header
// when present (for 429 responses). baseOverride, when non-zero, replaces the
// package-level retryBaseDelay constant (used by tests to avoid real waits).
func retryBackoff(attempt int, header http.Header, baseOverride time.Duration) time.Duration {
// Check Retry-After header (seconds value or HTTP-date; we only parse seconds).
if ra := header.Get("Retry-After"); ra != "" {
if secs, err := strconv.Atoi(ra); err == nil && secs > 0 {
return time.Duration(secs) * time.Second
}
}
base := retryBaseDelay
if baseOverride > 0 {
base = baseOverride
}
return base * (1 << attempt)
}
// truncateBody returns a string of at most maxLen bytes from b, appending
// "..." when truncated. Used for readable log output of error response bodies.
func truncateBody(b []byte, maxLen int) string {
if len(b) <= maxLen {
return string(b)
}
return string(b[:maxLen]) + "..."
}
// convertMessage maps a provider.Message into a native wire message. // convertMessage maps a provider.Message into a native wire message.
func convertMessage(msg provider.Message) (nativeChatMessage, error) { func convertMessage(msg provider.Message) (nativeChatMessage, error) {
out := nativeChatMessage{ out := nativeChatMessage{
+314
View File
@@ -571,3 +571,317 @@ func equalStrings(a, b []string) bool {
} }
return true return true
} }
// --- Retry tests ---
// newTestNative creates a Provider with a minimal retry delay so tests run fast.
func newTestNative(apiKey, baseURL string) *Provider {
p := newNative(apiKey, baseURL)
p.retryBaseDelayOverride = 1 * time.Millisecond
return p
}
func TestRetryOnTransientHTTPError(t *testing.T) {
t.Run("503 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Drain the request body so the client can reuse the connection.
_, _ = io.ReadAll(r.Body)
attempts++
if attempts <= 2 {
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"model is temporarily overloaded"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"ok"},"done":true,"prompt_eval_count":1,"eval_count":1}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("key", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retries, got error: %v", err)
}
if resp.Text != "ok" {
t.Errorf("Text: want %q, got %q", "ok", resp.Text)
}
if attempts != 3 {
t.Errorf("attempts: want 3, got %d", attempts)
}
})
t.Run("429 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(429)
_, _ = w.Write([]byte(`{"error":"rate limited"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"done"},"done":true}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if resp.Text != "done" {
t.Errorf("Text: want %q, got %q", "done", resp.Text)
}
if attempts != 2 {
t.Errorf("attempts: want 2, got %d", attempts)
}
})
t.Run("502 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(502)
_, _ = w.Write([]byte(`Bad Gateway`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"yes"},"done":true}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if resp.Text != "yes" {
t.Errorf("Text: want %q, got %q", "yes", resp.Text)
}
})
t.Run("exhausts retries and returns error", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
_, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error after exhausting retries")
}
if !strings.Contains(err.Error(), "503") {
t.Errorf("error should mention status 503, got: %v", err)
}
// 1 initial + 3 retries = 4 total
if attempts != 4 {
t.Errorf("attempts: want 4, got %d", attempts)
}
})
t.Run("400 is not retried", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(400)
_, _ = w.Write([]byte(`{"error":"bad request"}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
_, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error for 400")
}
if attempts != 1 {
t.Errorf("attempts: want 1 (no retries for 400), got %d", attempts)
}
})
t.Run("context cancellation during backoff aborts retry", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
}))
t.Cleanup(srv.Close)
ctx, cancel := context.WithCancel(context.Background())
// Cancel shortly after the first attempt so the backoff wait is interrupted.
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
p := newTestNative("", srv.URL)
// Use a longer backoff so the context cancel fires during the wait.
p.retryBaseDelayOverride = 2 * time.Second
_, err := p.Complete(ctx, provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error from cancelled context")
}
if !strings.Contains(err.Error(), "canceled") && !strings.Contains(err.Error(), "context") {
t.Errorf("expected context error, got: %v", err)
}
// Should have made only 1 attempt before the context cancelled during backoff.
if attempts != 1 {
t.Errorf("attempts: want 1, got %d", attempts)
}
})
t.Run("stream retries on 503 then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
return
}
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"streamed"},"done":true,"prompt_eval_count":1,"eval_count":1}` + "\n"))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
events := collectStream(t, p, provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
var gotDone bool
for _, ev := range events {
if ev.Type == provider.StreamEventDone && ev.Response != nil {
if ev.Response.Text != "streamed" {
t.Errorf("Response.Text: want %q, got %q", "streamed", ev.Response.Text)
}
gotDone = true
}
}
if !gotDone {
t.Fatal("expected StreamEventDone")
}
if attempts != 2 {
t.Errorf("attempts: want 2, got %d", attempts)
}
})
}
func TestRetryBackoff(t *testing.T) {
t.Run("exponential backoff without Retry-After", func(t *testing.T) {
h := http.Header{}
d0 := retryBackoff(0, h, 0)
d1 := retryBackoff(1, h, 0)
d2 := retryBackoff(2, h, 0)
if d0 != 1*time.Second {
t.Errorf("attempt 0: want 1s, got %v", d0)
}
if d1 != 2*time.Second {
t.Errorf("attempt 1: want 2s, got %v", d1)
}
if d2 != 4*time.Second {
t.Errorf("attempt 2: want 4s, got %v", d2)
}
})
t.Run("Retry-After header overrides backoff", func(t *testing.T) {
h := http.Header{}
h.Set("Retry-After", "5")
d := retryBackoff(0, h, 0)
if d != 5*time.Second {
t.Errorf("want 5s from Retry-After, got %v", d)
}
})
t.Run("invalid Retry-After falls back to exponential", func(t *testing.T) {
h := http.Header{}
h.Set("Retry-After", "not-a-number")
d := retryBackoff(1, h, 0)
if d != 2*time.Second {
t.Errorf("want 2s fallback, got %v", d)
}
})
t.Run("baseOverride replaces default base delay", func(t *testing.T) {
h := http.Header{}
d := retryBackoff(0, h, 500*time.Millisecond)
if d != 500*time.Millisecond {
t.Errorf("attempt 0 with 500ms override: want 500ms, got %v", d)
}
d1 := retryBackoff(2, h, 500*time.Millisecond)
if d1 != 2*time.Second {
t.Errorf("attempt 2 with 500ms override: want 2s, got %v", d1)
}
})
}
func TestIsTransientHTTPStatus(t *testing.T) {
cases := []struct {
code int
want bool
}{
{200, false},
{400, false},
{401, false},
{403, false},
{404, false},
{429, true},
{500, false},
{502, true},
{503, true},
}
for _, c := range cases {
if got := isTransientHTTPStatus(c.code); got != c.want {
t.Errorf("isTransientHTTPStatus(%d): want %v, got %v", c.code, c.want, got)
}
}
}
func TestTruncateBody(t *testing.T) {
short := "hello"
if got := truncateBody([]byte(short), 10); got != short {
t.Errorf("short: want %q, got %q", short, got)
}
long := "hello world this is a very long string"
got := truncateBody([]byte(long), 10)
if got != "hello worl..." {
t.Errorf("long: want %q, got %q", "hello worl...", got)
}
}