feat(ollama): add automatic retry with exponential backoff for transient HTTP errors
CI / Build, Test & Lint (push) Successful in 10m50s

Ollama Cloud returns HTTP 503 when the model is temporarily overloaded,
429 on rate limit, and 502 on upstream failures. These are transient
conditions that resolve on retry. Previously they bubbled up as hard
errors, forcing callers to implement their own retry logic.

The retry is implemented at the HTTP transport level in doChatRequest,
so both Complete and Stream benefit transparently. Strategy: up to 3
retries with exponential backoff (1s, 2s, 4s), Retry-After header
respected for 429, context cancellation checked between retries.
Non-transient errors (400, 401, 403, 404, 500) are never retried.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-25 11:55:43 -04:00
parent 6bac4cb3ed
commit 67c3ebe067
2 changed files with 407 additions and 11 deletions
+314
View File
@@ -571,3 +571,317 @@ func equalStrings(a, b []string) bool {
}
return true
}
// --- Retry tests ---
// newTestNative creates a Provider with a minimal retry delay so tests run fast.
func newTestNative(apiKey, baseURL string) *Provider {
p := newNative(apiKey, baseURL)
p.retryBaseDelayOverride = 1 * time.Millisecond
return p
}
func TestRetryOnTransientHTTPError(t *testing.T) {
t.Run("503 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Drain the request body so the client can reuse the connection.
_, _ = io.ReadAll(r.Body)
attempts++
if attempts <= 2 {
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"model is temporarily overloaded"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"ok"},"done":true,"prompt_eval_count":1,"eval_count":1}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("key", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retries, got error: %v", err)
}
if resp.Text != "ok" {
t.Errorf("Text: want %q, got %q", "ok", resp.Text)
}
if attempts != 3 {
t.Errorf("attempts: want 3, got %d", attempts)
}
})
t.Run("429 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(429)
_, _ = w.Write([]byte(`{"error":"rate limited"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"done"},"done":true}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if resp.Text != "done" {
t.Errorf("Text: want %q, got %q", "done", resp.Text)
}
if attempts != 2 {
t.Errorf("attempts: want 2, got %d", attempts)
}
})
t.Run("502 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(502)
_, _ = w.Write([]byte(`Bad Gateway`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"yes"},"done":true}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if resp.Text != "yes" {
t.Errorf("Text: want %q, got %q", "yes", resp.Text)
}
})
t.Run("exhausts retries and returns error", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
_, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error after exhausting retries")
}
if !strings.Contains(err.Error(), "503") {
t.Errorf("error should mention status 503, got: %v", err)
}
// 1 initial + 3 retries = 4 total
if attempts != 4 {
t.Errorf("attempts: want 4, got %d", attempts)
}
})
t.Run("400 is not retried", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(400)
_, _ = w.Write([]byte(`{"error":"bad request"}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
_, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error for 400")
}
if attempts != 1 {
t.Errorf("attempts: want 1 (no retries for 400), got %d", attempts)
}
})
t.Run("context cancellation during backoff aborts retry", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
}))
t.Cleanup(srv.Close)
ctx, cancel := context.WithCancel(context.Background())
// Cancel shortly after the first attempt so the backoff wait is interrupted.
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
p := newTestNative("", srv.URL)
// Use a longer backoff so the context cancel fires during the wait.
p.retryBaseDelayOverride = 2 * time.Second
_, err := p.Complete(ctx, provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error from cancelled context")
}
if !strings.Contains(err.Error(), "canceled") && !strings.Contains(err.Error(), "context") {
t.Errorf("expected context error, got: %v", err)
}
// Should have made only 1 attempt before the context cancelled during backoff.
if attempts != 1 {
t.Errorf("attempts: want 1, got %d", attempts)
}
})
t.Run("stream retries on 503 then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
return
}
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"streamed"},"done":true,"prompt_eval_count":1,"eval_count":1}` + "\n"))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
events := collectStream(t, p, provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
var gotDone bool
for _, ev := range events {
if ev.Type == provider.StreamEventDone && ev.Response != nil {
if ev.Response.Text != "streamed" {
t.Errorf("Response.Text: want %q, got %q", "streamed", ev.Response.Text)
}
gotDone = true
}
}
if !gotDone {
t.Fatal("expected StreamEventDone")
}
if attempts != 2 {
t.Errorf("attempts: want 2, got %d", attempts)
}
})
}
func TestRetryBackoff(t *testing.T) {
t.Run("exponential backoff without Retry-After", func(t *testing.T) {
h := http.Header{}
d0 := retryBackoff(0, h, 0)
d1 := retryBackoff(1, h, 0)
d2 := retryBackoff(2, h, 0)
if d0 != 1*time.Second {
t.Errorf("attempt 0: want 1s, got %v", d0)
}
if d1 != 2*time.Second {
t.Errorf("attempt 1: want 2s, got %v", d1)
}
if d2 != 4*time.Second {
t.Errorf("attempt 2: want 4s, got %v", d2)
}
})
t.Run("Retry-After header overrides backoff", func(t *testing.T) {
h := http.Header{}
h.Set("Retry-After", "5")
d := retryBackoff(0, h, 0)
if d != 5*time.Second {
t.Errorf("want 5s from Retry-After, got %v", d)
}
})
t.Run("invalid Retry-After falls back to exponential", func(t *testing.T) {
h := http.Header{}
h.Set("Retry-After", "not-a-number")
d := retryBackoff(1, h, 0)
if d != 2*time.Second {
t.Errorf("want 2s fallback, got %v", d)
}
})
t.Run("baseOverride replaces default base delay", func(t *testing.T) {
h := http.Header{}
d := retryBackoff(0, h, 500*time.Millisecond)
if d != 500*time.Millisecond {
t.Errorf("attempt 0 with 500ms override: want 500ms, got %v", d)
}
d1 := retryBackoff(2, h, 500*time.Millisecond)
if d1 != 2*time.Second {
t.Errorf("attempt 2 with 500ms override: want 2s, got %v", d1)
}
})
}
func TestIsTransientHTTPStatus(t *testing.T) {
cases := []struct {
code int
want bool
}{
{200, false},
{400, false},
{401, false},
{403, false},
{404, false},
{429, true},
{500, false},
{502, true},
{503, true},
}
for _, c := range cases {
if got := isTransientHTTPStatus(c.code); got != c.want {
t.Errorf("isTransientHTTPStatus(%d): want %v, got %v", c.code, c.want, got)
}
}
}
func TestTruncateBody(t *testing.T) {
short := "hello"
if got := truncateBody([]byte(short), 10); got != short {
t.Errorf("short: want %q, got %q", short, got)
}
long := "hello world this is a very long string"
got := truncateBody([]byte(long), 10)
if got != "hello worl..." {
t.Errorf("long: want %q, got %q", "hello worl...", got)
}
}