feat(ollama): add automatic retry with exponential backoff for transient HTTP errors
CI / Build, Test & Lint (push) Successful in 10m50s
CI / Build, Test & Lint (push) Successful in 10m50s
Ollama Cloud returns HTTP 503 when the model is temporarily overloaded, 429 on rate limit, and 502 on upstream failures. These are transient conditions that resolve on retry. Previously they bubbled up as hard errors, forcing callers to implement their own retry logic. The retry is implemented at the HTTP transport level in doChatRequest, so both Complete and Stream benefit transparently. Strategy: up to 3 retries with exponential backoff (1s, 2s, 4s), Retry-After header respected for 429, context cancellation checked between retries. Non-transient errors (400, 401, 403, 404, 500) are never retried. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -571,3 +571,317 @@ func equalStrings(a, b []string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// --- Retry tests ---
|
||||
|
||||
// newTestNative creates a Provider with a minimal retry delay so tests run fast.
|
||||
func newTestNative(apiKey, baseURL string) *Provider {
|
||||
p := newNative(apiKey, baseURL)
|
||||
p.retryBaseDelayOverride = 1 * time.Millisecond
|
||||
return p
|
||||
}
|
||||
|
||||
func TestRetryOnTransientHTTPError(t *testing.T) {
|
||||
t.Run("503 retries then succeeds", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Drain the request body so the client can reuse the connection.
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
if attempts <= 2 {
|
||||
w.WriteHeader(503)
|
||||
_, _ = w.Write([]byte(`{"error":"model is temporarily overloaded"}`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"ok"},"done":true,"prompt_eval_count":1,"eval_count":1}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := newTestNative("key", srv.URL)
|
||||
resp, err := p.Complete(context.Background(), provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after retries, got error: %v", err)
|
||||
}
|
||||
if resp.Text != "ok" {
|
||||
t.Errorf("Text: want %q, got %q", "ok", resp.Text)
|
||||
}
|
||||
if attempts != 3 {
|
||||
t.Errorf("attempts: want 3, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("429 retries then succeeds", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.WriteHeader(429)
|
||||
_, _ = w.Write([]byte(`{"error":"rate limited"}`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"done"},"done":true}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := newTestNative("", srv.URL)
|
||||
resp, err := p.Complete(context.Background(), provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after retry, got error: %v", err)
|
||||
}
|
||||
if resp.Text != "done" {
|
||||
t.Errorf("Text: want %q, got %q", "done", resp.Text)
|
||||
}
|
||||
if attempts != 2 {
|
||||
t.Errorf("attempts: want 2, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("502 retries then succeeds", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.WriteHeader(502)
|
||||
_, _ = w.Write([]byte(`Bad Gateway`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"yes"},"done":true}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := newTestNative("", srv.URL)
|
||||
resp, err := p.Complete(context.Background(), provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after retry, got error: %v", err)
|
||||
}
|
||||
if resp.Text != "yes" {
|
||||
t.Errorf("Text: want %q, got %q", "yes", resp.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exhausts retries and returns error", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
w.WriteHeader(503)
|
||||
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := newTestNative("", srv.URL)
|
||||
_, err := p.Complete(context.Background(), provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after exhausting retries")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "503") {
|
||||
t.Errorf("error should mention status 503, got: %v", err)
|
||||
}
|
||||
// 1 initial + 3 retries = 4 total
|
||||
if attempts != 4 {
|
||||
t.Errorf("attempts: want 4, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("400 is not retried", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
w.WriteHeader(400)
|
||||
_, _ = w.Write([]byte(`{"error":"bad request"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := newTestNative("", srv.URL)
|
||||
_, err := p.Complete(context.Background(), provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 400")
|
||||
}
|
||||
if attempts != 1 {
|
||||
t.Errorf("attempts: want 1 (no retries for 400), got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context cancellation during backoff aborts retry", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
w.WriteHeader(503)
|
||||
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Cancel shortly after the first attempt so the backoff wait is interrupted.
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
p := newTestNative("", srv.URL)
|
||||
// Use a longer backoff so the context cancel fires during the wait.
|
||||
p.retryBaseDelayOverride = 2 * time.Second
|
||||
_, err := p.Complete(ctx, provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from cancelled context")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "canceled") && !strings.Contains(err.Error(), "context") {
|
||||
t.Errorf("expected context error, got: %v", err)
|
||||
}
|
||||
// Should have made only 1 attempt before the context cancelled during backoff.
|
||||
if attempts != 1 {
|
||||
t.Errorf("attempts: want 1, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stream retries on 503 then succeeds", func(t *testing.T) {
|
||||
var attempts int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.ReadAll(r.Body)
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.WriteHeader(503)
|
||||
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"streamed"},"done":true,"prompt_eval_count":1,"eval_count":1}` + "\n"))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := newTestNative("", srv.URL)
|
||||
events := collectStream(t, p, provider.Request{
|
||||
Model: "test",
|
||||
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
|
||||
var gotDone bool
|
||||
for _, ev := range events {
|
||||
if ev.Type == provider.StreamEventDone && ev.Response != nil {
|
||||
if ev.Response.Text != "streamed" {
|
||||
t.Errorf("Response.Text: want %q, got %q", "streamed", ev.Response.Text)
|
||||
}
|
||||
gotDone = true
|
||||
}
|
||||
}
|
||||
if !gotDone {
|
||||
t.Fatal("expected StreamEventDone")
|
||||
}
|
||||
if attempts != 2 {
|
||||
t.Errorf("attempts: want 2, got %d", attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryBackoff(t *testing.T) {
|
||||
t.Run("exponential backoff without Retry-After", func(t *testing.T) {
|
||||
h := http.Header{}
|
||||
d0 := retryBackoff(0, h, 0)
|
||||
d1 := retryBackoff(1, h, 0)
|
||||
d2 := retryBackoff(2, h, 0)
|
||||
|
||||
if d0 != 1*time.Second {
|
||||
t.Errorf("attempt 0: want 1s, got %v", d0)
|
||||
}
|
||||
if d1 != 2*time.Second {
|
||||
t.Errorf("attempt 1: want 2s, got %v", d1)
|
||||
}
|
||||
if d2 != 4*time.Second {
|
||||
t.Errorf("attempt 2: want 4s, got %v", d2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Retry-After header overrides backoff", func(t *testing.T) {
|
||||
h := http.Header{}
|
||||
h.Set("Retry-After", "5")
|
||||
d := retryBackoff(0, h, 0)
|
||||
if d != 5*time.Second {
|
||||
t.Errorf("want 5s from Retry-After, got %v", d)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid Retry-After falls back to exponential", func(t *testing.T) {
|
||||
h := http.Header{}
|
||||
h.Set("Retry-After", "not-a-number")
|
||||
d := retryBackoff(1, h, 0)
|
||||
if d != 2*time.Second {
|
||||
t.Errorf("want 2s fallback, got %v", d)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("baseOverride replaces default base delay", func(t *testing.T) {
|
||||
h := http.Header{}
|
||||
d := retryBackoff(0, h, 500*time.Millisecond)
|
||||
if d != 500*time.Millisecond {
|
||||
t.Errorf("attempt 0 with 500ms override: want 500ms, got %v", d)
|
||||
}
|
||||
d1 := retryBackoff(2, h, 500*time.Millisecond)
|
||||
if d1 != 2*time.Second {
|
||||
t.Errorf("attempt 2 with 500ms override: want 2s, got %v", d1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsTransientHTTPStatus(t *testing.T) {
|
||||
cases := []struct {
|
||||
code int
|
||||
want bool
|
||||
}{
|
||||
{200, false},
|
||||
{400, false},
|
||||
{401, false},
|
||||
{403, false},
|
||||
{404, false},
|
||||
{429, true},
|
||||
{500, false},
|
||||
{502, true},
|
||||
{503, true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := isTransientHTTPStatus(c.code); got != c.want {
|
||||
t.Errorf("isTransientHTTPStatus(%d): want %v, got %v", c.code, c.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateBody(t *testing.T) {
|
||||
short := "hello"
|
||||
if got := truncateBody([]byte(short), 10); got != short {
|
||||
t.Errorf("short: want %q, got %q", short, got)
|
||||
}
|
||||
|
||||
long := "hello world this is a very long string"
|
||||
got := truncateBody([]byte(long), 10)
|
||||
if got != "hello worl..." {
|
||||
t.Errorf("long: want %q, got %q", "hello worl...", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user