diff --git a/v2/ollama/native.go b/v2/ollama/native.go index 5fd1c83..acbf7ad 100644 --- a/v2/ollama/native.go +++ b/v2/ollama/native.go @@ -12,8 +12,11 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" + "strconv" "strings" + "time" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) @@ -25,6 +28,22 @@ const DefaultLocalBaseURL = "http://localhost:11434" // DefaultCloudBaseURL is the default base URL for Ollama Cloud. const DefaultCloudBaseURL = "https://ollama.com" +// retryMaxAttempts is the maximum number of retry attempts for transient HTTP +// errors (503, 429, 502). Total attempts = 1 initial + retryMaxAttempts. +const retryMaxAttempts = 3 + +// retryBaseDelay is the base delay for exponential backoff between retries. +// Actual delays: 1s, 2s, 4s (base * 2^attempt). +const retryBaseDelay = 1 * time.Second + +// isTransientHTTPStatus reports whether the HTTP status code indicates a +// transient server-side condition that may resolve on retry. +func isTransientHTTPStatus(code int) bool { + return code == http.StatusBadGateway || // 502 + code == http.StatusServiceUnavailable || // 503 + code == http.StatusTooManyRequests // 429 +} + // Provider implements provider.Provider over Ollama's native /api/chat // endpoint. An empty apiKey means local-mode (no Authorization header sent); // a non-empty apiKey is sent as a Bearer token (cloud-mode). @@ -32,6 +51,10 @@ type Provider struct { apiKey string baseURL string client *http.Client + + // retryBaseDelayOverride, when non-zero, replaces retryBaseDelay for + // testing. Production code leaves this at the zero value. + retryBaseDelayOverride time.Duration } // newNative constructs a native Ollama provider. Callers should use the @@ -420,22 +443,81 @@ func (p *Provider) buildChatRequest(req provider.Request, stream bool) ([]byte, } // doChatRequest POSTs the wire body to /api/chat and returns the raw HTTP -// response. The caller is responsible for closing the response body. +// response. Transient HTTP errors (502, 503, 429) are retried with exponential +// backoff up to retryMaxAttempts times. The caller is responsible for closing +// the response body. func (p *Provider) doChatRequest(ctx context.Context, body []byte) (*http.Response, error) { url := strings.TrimRight(p.baseURL, "/") + "/api/chat" - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("ollama: build request: %w", err) + + for attempt := 0; ; attempt++ { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("ollama: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ollama: HTTP request: %w", err) + } + + // On success or non-transient error, return immediately. + if !isTransientHTTPStatus(resp.StatusCode) || attempt >= retryMaxAttempts { + return resp, nil + } + + // Transient error — drain and close the body before retrying. + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + delay := retryBackoff(attempt, resp.Header, p.retryBaseDelayOverride) + slog.Info("ollama: retrying after transient HTTP error", + "status", resp.StatusCode, + "attempt", attempt+1, + "max_attempts", retryMaxAttempts, + "delay", delay, + "body", truncateBody(respBody, 200), + ) + + // Wait for backoff or context cancellation. + timer := time.NewTimer(delay) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } } - httpReq.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) +} + +// retryBackoff computes the delay before the next retry attempt. It uses +// exponential backoff (base * 2^attempt), but respects the Retry-After header +// when present (for 429 responses). baseOverride, when non-zero, replaces the +// package-level retryBaseDelay constant (used by tests to avoid real waits). +func retryBackoff(attempt int, header http.Header, baseOverride time.Duration) time.Duration { + // Check Retry-After header (seconds value or HTTP-date; we only parse seconds). + if ra := header.Get("Retry-After"); ra != "" { + if secs, err := strconv.Atoi(ra); err == nil && secs > 0 { + return time.Duration(secs) * time.Second + } } - resp, err := p.client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("ollama: HTTP request: %w", err) + base := retryBaseDelay + if baseOverride > 0 { + base = baseOverride } - return resp, nil + return base * (1 << attempt) +} + +// truncateBody returns a string of at most maxLen bytes from b, appending +// "..." when truncated. Used for readable log output of error response bodies. +func truncateBody(b []byte, maxLen int) string { + if len(b) <= maxLen { + return string(b) + } + return string(b[:maxLen]) + "..." } // convertMessage maps a provider.Message into a native wire message. diff --git a/v2/ollama/native_test.go b/v2/ollama/native_test.go index 311c45b..3df1007 100644 --- a/v2/ollama/native_test.go +++ b/v2/ollama/native_test.go @@ -571,3 +571,317 @@ func equalStrings(a, b []string) bool { } return true } + +// --- Retry tests --- + +// newTestNative creates a Provider with a minimal retry delay so tests run fast. +func newTestNative(apiKey, baseURL string) *Provider { + p := newNative(apiKey, baseURL) + p.retryBaseDelayOverride = 1 * time.Millisecond + return p +} + +func TestRetryOnTransientHTTPError(t *testing.T) { + t.Run("503 retries then succeeds", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Drain the request body so the client can reuse the connection. + _, _ = io.ReadAll(r.Body) + attempts++ + if attempts <= 2 { + w.WriteHeader(503) + _, _ = w.Write([]byte(`{"error":"model is temporarily overloaded"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"ok"},"done":true,"prompt_eval_count":1,"eval_count":1}`)) + })) + t.Cleanup(srv.Close) + + p := newTestNative("key", srv.URL) + resp, err := p.Complete(context.Background(), provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("expected success after retries, got error: %v", err) + } + if resp.Text != "ok" { + t.Errorf("Text: want %q, got %q", "ok", resp.Text) + } + if attempts != 3 { + t.Errorf("attempts: want 3, got %d", attempts) + } + }) + + t.Run("429 retries then succeeds", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + attempts++ + if attempts == 1 { + w.WriteHeader(429) + _, _ = w.Write([]byte(`{"error":"rate limited"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"done"},"done":true}`)) + })) + t.Cleanup(srv.Close) + + p := newTestNative("", srv.URL) + resp, err := p.Complete(context.Background(), provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("expected success after retry, got error: %v", err) + } + if resp.Text != "done" { + t.Errorf("Text: want %q, got %q", "done", resp.Text) + } + if attempts != 2 { + t.Errorf("attempts: want 2, got %d", attempts) + } + }) + + t.Run("502 retries then succeeds", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + attempts++ + if attempts == 1 { + w.WriteHeader(502) + _, _ = w.Write([]byte(`Bad Gateway`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"yes"},"done":true}`)) + })) + t.Cleanup(srv.Close) + + p := newTestNative("", srv.URL) + resp, err := p.Complete(context.Background(), provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("expected success after retry, got error: %v", err) + } + if resp.Text != "yes" { + t.Errorf("Text: want %q, got %q", "yes", resp.Text) + } + }) + + t.Run("exhausts retries and returns error", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + attempts++ + w.WriteHeader(503) + _, _ = w.Write([]byte(`{"error":"overloaded"}`)) + })) + t.Cleanup(srv.Close) + + p := newTestNative("", srv.URL) + _, err := p.Complete(context.Background(), provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error after exhausting retries") + } + if !strings.Contains(err.Error(), "503") { + t.Errorf("error should mention status 503, got: %v", err) + } + // 1 initial + 3 retries = 4 total + if attempts != 4 { + t.Errorf("attempts: want 4, got %d", attempts) + } + }) + + t.Run("400 is not retried", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + attempts++ + w.WriteHeader(400) + _, _ = w.Write([]byte(`{"error":"bad request"}`)) + })) + t.Cleanup(srv.Close) + + p := newTestNative("", srv.URL) + _, err := p.Complete(context.Background(), provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error for 400") + } + if attempts != 1 { + t.Errorf("attempts: want 1 (no retries for 400), got %d", attempts) + } + }) + + t.Run("context cancellation during backoff aborts retry", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + attempts++ + w.WriteHeader(503) + _, _ = w.Write([]byte(`{"error":"overloaded"}`)) + })) + t.Cleanup(srv.Close) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel shortly after the first attempt so the backoff wait is interrupted. + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + p := newTestNative("", srv.URL) + // Use a longer backoff so the context cancel fires during the wait. + p.retryBaseDelayOverride = 2 * time.Second + _, err := p.Complete(ctx, provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error from cancelled context") + } + if !strings.Contains(err.Error(), "canceled") && !strings.Contains(err.Error(), "context") { + t.Errorf("expected context error, got: %v", err) + } + // Should have made only 1 attempt before the context cancelled during backoff. + if attempts != 1 { + t.Errorf("attempts: want 1, got %d", attempts) + } + }) + + t.Run("stream retries on 503 then succeeds", func(t *testing.T) { + var attempts int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + attempts++ + if attempts == 1 { + w.WriteHeader(503) + _, _ = w.Write([]byte(`{"error":"overloaded"}`)) + return + } + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"streamed"},"done":true,"prompt_eval_count":1,"eval_count":1}` + "\n")) + })) + t.Cleanup(srv.Close) + + p := newTestNative("", srv.URL) + events := collectStream(t, p, provider.Request{ + Model: "test", + Messages: []provider.Message{{Role: "user", Content: "hi"}}, + }) + + var gotDone bool + for _, ev := range events { + if ev.Type == provider.StreamEventDone && ev.Response != nil { + if ev.Response.Text != "streamed" { + t.Errorf("Response.Text: want %q, got %q", "streamed", ev.Response.Text) + } + gotDone = true + } + } + if !gotDone { + t.Fatal("expected StreamEventDone") + } + if attempts != 2 { + t.Errorf("attempts: want 2, got %d", attempts) + } + }) +} + +func TestRetryBackoff(t *testing.T) { + t.Run("exponential backoff without Retry-After", func(t *testing.T) { + h := http.Header{} + d0 := retryBackoff(0, h, 0) + d1 := retryBackoff(1, h, 0) + d2 := retryBackoff(2, h, 0) + + if d0 != 1*time.Second { + t.Errorf("attempt 0: want 1s, got %v", d0) + } + if d1 != 2*time.Second { + t.Errorf("attempt 1: want 2s, got %v", d1) + } + if d2 != 4*time.Second { + t.Errorf("attempt 2: want 4s, got %v", d2) + } + }) + + t.Run("Retry-After header overrides backoff", func(t *testing.T) { + h := http.Header{} + h.Set("Retry-After", "5") + d := retryBackoff(0, h, 0) + if d != 5*time.Second { + t.Errorf("want 5s from Retry-After, got %v", d) + } + }) + + t.Run("invalid Retry-After falls back to exponential", func(t *testing.T) { + h := http.Header{} + h.Set("Retry-After", "not-a-number") + d := retryBackoff(1, h, 0) + if d != 2*time.Second { + t.Errorf("want 2s fallback, got %v", d) + } + }) + + t.Run("baseOverride replaces default base delay", func(t *testing.T) { + h := http.Header{} + d := retryBackoff(0, h, 500*time.Millisecond) + if d != 500*time.Millisecond { + t.Errorf("attempt 0 with 500ms override: want 500ms, got %v", d) + } + d1 := retryBackoff(2, h, 500*time.Millisecond) + if d1 != 2*time.Second { + t.Errorf("attempt 2 with 500ms override: want 2s, got %v", d1) + } + }) +} + +func TestIsTransientHTTPStatus(t *testing.T) { + cases := []struct { + code int + want bool + }{ + {200, false}, + {400, false}, + {401, false}, + {403, false}, + {404, false}, + {429, true}, + {500, false}, + {502, true}, + {503, true}, + } + for _, c := range cases { + if got := isTransientHTTPStatus(c.code); got != c.want { + t.Errorf("isTransientHTTPStatus(%d): want %v, got %v", c.code, c.want, got) + } + } +} + +func TestTruncateBody(t *testing.T) { + short := "hello" + if got := truncateBody([]byte(short), 10); got != short { + t.Errorf("short: want %q, got %q", short, got) + } + + long := "hello world this is a very long string" + got := truncateBody([]byte(long), 10) + if got != "hello worl..." { + t.Errorf("long: want %q, got %q", "hello worl...", got) + } +}