Files
steve 67c3ebe067
CI / Build, Test & Lint (push) Successful in 10m50s
feat(ollama): add automatic retry with exponential backoff for transient HTTP errors
Ollama Cloud returns HTTP 503 when the model is temporarily overloaded,
429 on rate limit, and 502 on upstream failures. These are transient
conditions that resolve on retry. Previously they bubbled up as hard
errors, forcing callers to implement their own retry logic.

The retry is implemented at the HTTP transport level in doChatRequest,
so both Complete and Stream benefit transparently. Strategy: up to 3
retries with exponential backoff (1s, 2s, 4s), Retry-After header
respected for 429, context cancellation checked between retries.
Non-transient errors (400, 401, 403, 404, 500) are never retried.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 11:58:25 -04:00

888 lines
26 KiB
Go

package ollama
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// captureRequest holds the details of an inbound HTTP request as recorded by
// the test servers in this file (newTestServer, streamServer), so tests can
// assert on what the client actually sent. It only stores data; the response
// is configured on the server helpers, not here.
type captureRequest struct {
	// method is the HTTP method of the request (r.Method).
	method string
	// path is the URL path of the request (r.URL.Path).
	path string
	// authHeader is the value of the Authorization request header; it is
	// empty when the client sent none.
	authHeader string
	// contentType is the value of the Content-Type request header.
	contentType string
	// body is the raw request body as read by the server handler.
	body []byte
	// parsedBody is body decoded as a JSON object. The handlers ignore
	// decode errors, so it may be nil or incomplete for non-JSON bodies.
	parsedBody map[string]any
}
// newTestServer starts an httptest.Server that records each inbound request
// into captured and then replies with the given status code and body.
//
// respContentType is used as the response Content-Type header; an empty
// string selects "application/json". The server is shut down automatically
// through t.Cleanup.
func newTestServer(t *testing.T, captured *captureRequest, status int, respBody string, respContentType string) *httptest.Server {
	t.Helper()

	// Resolve the default once, before the server starts, instead of
	// reassigning the captured parameter from inside the handler on every
	// request.
	if respContentType == "" {
		respContentType = "application/json"
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured.method = r.Method
		captured.path = r.URL.Path
		captured.authHeader = r.Header.Get("Authorization")
		captured.contentType = r.Header.Get("Content-Type")

		// Read and decode errors are deliberately ignored: a test that needs
		// the body will fail on its own assertions if it is missing.
		body, _ := io.ReadAll(r.Body)
		captured.body = body

		// Decode into a fresh map. json.Unmarshal reuses a non-nil map and
		// keeps its existing entries, which would leak keys from an earlier
		// request into the record of the current one.
		var parsed map[string]any
		_ = json.Unmarshal(body, &parsed)
		captured.parsedBody = parsed

		w.Header().Set("Content-Type", respContentType)
		w.WriteHeader(status)
		_, _ = w.Write([]byte(respBody))
	}))
	t.Cleanup(srv.Close)
	return srv
}
// TestCompleteBasic drives a non-streaming Complete call against a test
// server and checks both directions: the HTTP request the provider sends
// (method, path, headers, JSON body) and the provider.Response it builds from
// the JSON reply (text and token usage).
func TestCompleteBasic(t *testing.T) {
	const reply = `{
"model": "kimi-k2.5",
"message": {"role": "assistant", "content": "hello there"},
"done": true,
"done_reason": "stop",
"prompt_eval_count": 10,
"eval_count": 3
}`
	const wantAuth = "Bearer test-key"

	captured := &captureRequest{}
	server := newTestServer(t, captured, 200, reply, "")
	client := newNative("test-key", server.URL)

	request := provider.Request{
		Model:    "kimi-k2.5",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	got, err := client.Complete(context.Background(), request)
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	// Request line and headers.
	if captured.method != "POST" {
		t.Errorf("method: want POST, got %q", captured.method)
	}
	if captured.path != "/api/chat" {
		t.Errorf("path: want /api/chat, got %q", captured.path)
	}
	if captured.authHeader != wantAuth {
		t.Errorf("auth header: want %q, got %q", wantAuth, captured.authHeader)
	}
	if captured.contentType != "application/json" {
		t.Errorf("content-type: want application/json, got %q", captured.contentType)
	}

	// Request body.
	sentBody := captured.parsedBody
	if sentBody["model"] != "kimi-k2.5" {
		t.Errorf("body.model: want kimi-k2.5, got %v", sentBody["model"])
	}
	if sentBody["stream"] != false {
		t.Errorf("body.stream: want false, got %v", sentBody["stream"])
	}
	sentMsgs, _ := sentBody["messages"].([]any)
	if n := len(sentMsgs); n != 1 {
		t.Fatalf("messages: want 1 entry, got %d", n)
	}
	first, _ := sentMsgs[0].(map[string]any)
	if !(first["role"] == "user" && first["content"] == "hi") {
		t.Errorf("first message: want role=user content=hi, got %v", first)
	}

	// Response conversion.
	if got.Text != "hello there" {
		t.Errorf("Text: want %q, got %q", "hello there", got.Text)
	}
	usage := got.Usage
	if usage == nil {
		t.Fatal("Usage: want non-nil")
	}
	if !(usage.InputTokens == 10 && usage.OutputTokens == 3) {
		t.Errorf("Usage: want input=10 output=3, got input=%d output=%d", usage.InputTokens, usage.OutputTokens)
	}
	if usage.TotalTokens != 13 {
		t.Errorf("Usage.TotalTokens: want 13, got %d", usage.TotalTokens)
	}
}
// TestCompleteNoAuthHeaderWhenLocal checks that a provider built with an
// empty API key sends no Authorization header.
func TestCompleteNoAuthHeaderWhenLocal(t *testing.T) {
	captured := &captureRequest{}
	server := newTestServer(t, captured, 200, `{"message":{"role":"assistant","content":"ok"},"done":true}`, "")
	client := newNative("", server.URL)

	_, err := client.Complete(context.Background(), provider.Request{
		Model:    "llama3.2",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	if header := captured.authHeader; header != "" {
		t.Errorf("auth header: want empty (local mode), got %q", header)
	}
}
// TestVisionImagesEncoded checks that an image attached to a message is sent
// in the request message's "images" array as the raw base64 string.
func TestVisionImagesEncoded(t *testing.T) {
	captured := &captureRequest{}
	server := newTestServer(t, captured, 200, `{"message":{"role":"assistant","content":"a cat"},"done":true}`, "")
	client := newNative("", server.URL)

	userMsg := provider.Message{
		Role:    "user",
		Content: "what's in this?",
		Images: []provider.Image{
			{Base64: "AAAA", ContentType: "image/png"},
		},
	}
	_, err := client.Complete(context.Background(), provider.Request{
		Model:    "llava",
		Messages: []provider.Message{userMsg},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}

	sent, _ := captured.parsedBody["messages"].([]any)
	if n := len(sent); n != 1 {
		t.Fatalf("messages: want 1, got %d", n)
	}
	first, _ := sent[0].(map[string]any)
	images, _ := first["images"].([]any)
	if n := len(images); n != 1 {
		t.Fatalf("images: want 1 entry, got %d (msg=%v)", n, first)
	}
	if images[0] != "AAAA" {
		t.Errorf("images[0]: want raw base64 AAAA, got %v", images[0])
	}
}
// TestThinkingField checks how Request.Reasoning is reflected in the "think"
// field of the outgoing request body: omitted for an empty value, passed
// through as a string for effort levels, and sent as a JSON boolean for
// "true"/"false".
func TestThinkingField(t *testing.T) {
	type thinkCase struct {
		name      string
		reasoning string
		// want is the expected value of "think" in the body; nil means the
		// field must be absent.
		want any
	}
	table := []thinkCase{
		{name: "absent", reasoning: "", want: nil},
		{name: "high", reasoning: "high", want: "high"},
		{name: "low", reasoning: "low", want: "low"},
		{name: "medium", reasoning: "medium", want: "medium"},
		{name: "true", reasoning: "true", want: true},
		{name: "false", reasoning: "false", want: false},
	}

	for _, tc := range table {
		t.Run(tc.name, func(t *testing.T) {
			captured := &captureRequest{}
			server := newTestServer(t, captured, 200, `{"message":{"role":"assistant","content":"ok"},"done":true}`, "")
			client := newNative("", server.URL)

			request := provider.Request{
				Model:     "kimi-k2.5",
				Messages:  []provider.Message{{Role: "user", Content: "hi"}},
				Reasoning: tc.reasoning,
			}
			if _, err := client.Complete(context.Background(), request); err != nil {
				t.Fatalf("Complete: %v", err)
			}

			got, present := captured.parsedBody["think"]
			switch {
			case tc.want == nil && present:
				t.Errorf("think field should be absent, got %v", got)
			case tc.want == nil:
				// Absent, as expected; nothing further to compare.
			case !present:
				t.Fatalf("think field absent; want %v", tc.want)
			case got != tc.want:
				t.Errorf("think: want %v (%T), got %v (%T)", tc.want, tc.want, got, got)
			}
		})
	}
}
// TestToolRoundTrip covers both halves of a tool-calling exchange:
//
//  1. tool definitions are serialised into the request, and tool_calls in the
//     reply are converted into provider.Response.ToolCalls;
//  2. a follow-up request carrying the assistant's tool call and the
//     tool-role result is serialised with the expected message shapes.
func TestToolRoundTrip(t *testing.T) {
	t.Run("response tool_calls convert to provider.Response", func(t *testing.T) {
		resp := `{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{"function": {"name": "search", "arguments": {"query": "foo"}}}
]
},
"done": true,
"prompt_eval_count": 5,
"eval_count": 2
}`
		captured := &captureRequest{}
		srv := newTestServer(t, captured, 200, resp, "")
		p := newNative("", srv.URL)
		got, err := p.Complete(context.Background(), provider.Request{
			Model:    "kimi-k2.5",
			Messages: []provider.Message{{Role: "user", Content: "hi"}},
			Tools: []provider.ToolDef{
				{
					Name:        "search",
					Description: "Run a search",
					Schema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"query": map[string]any{"type": "string"},
						},
					},
				},
			},
		})
		if err != nil {
			t.Fatalf("Complete: %v", err)
		}

		// Verify request shape: tools array present.
		toolsArr, _ := captured.parsedBody["tools"].([]any)
		if len(toolsArr) != 1 {
			t.Fatalf("tools: want 1 entry, got %d", len(toolsArr))
		}
		t0, _ := toolsArr[0].(map[string]any)
		if t0["type"] != "function" {
			t.Errorf("tools[0].type: want function, got %v", t0["type"])
		}
		fn, _ := t0["function"].(map[string]any)
		if fn["name"] != "search" {
			t.Errorf("tools[0].function.name: want search, got %v", fn["name"])
		}

		// Verify response conversion.
		if len(got.ToolCalls) != 1 {
			t.Fatalf("ToolCalls: want 1, got %d", len(got.ToolCalls))
		}
		tc := got.ToolCalls[0]
		if tc.Name != "search" {
			t.Errorf("ToolCall.Name: want search, got %q", tc.Name)
		}
		// Arguments should be valid JSON containing query=foo.
		var args map[string]any
		if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
			t.Fatalf("ToolCall.Arguments not valid JSON: %v (got %q)", err, tc.Arguments)
		}
		if args["query"] != "foo" {
			t.Errorf("ToolCall.Arguments.query: want foo, got %v", args["query"])
		}
	})

	t.Run("subsequent request includes assistant tool_calls and tool-role response", func(t *testing.T) {
		resp := `{"message":{"role":"assistant","content":"done"},"done":true}`
		captured := &captureRequest{}
		srv := newTestServer(t, captured, 200, resp, "")
		p := newNative("", srv.URL)
		_, err := p.Complete(context.Background(), provider.Request{
			Model: "kimi-k2.5",
			Messages: []provider.Message{
				{Role: "user", Content: "search foo"},
				{
					Role: "assistant",
					ToolCalls: []provider.ToolCall{{
						ID:        "tc1",
						Name:      "search",
						Arguments: `{"query":"foo"}`,
					}},
				},
				{
					Role:       "tool",
					ToolCallID: "tc1",
					Content:    `{"result":"bar"}`,
				},
			},
		})
		if err != nil {
			t.Fatalf("Complete: %v", err)
		}

		msgs, _ := captured.parsedBody["messages"].([]any)
		if len(msgs) != 3 {
			t.Fatalf("messages: want 3, got %d", len(msgs))
		}

		// Assistant message must carry tool_calls with the JSON-object arguments.
		asst, _ := msgs[1].(map[string]any)
		if asst["role"] != "assistant" {
			t.Errorf("msgs[1].role: want assistant, got %v", asst["role"])
		}
		calls, _ := asst["tool_calls"].([]any)
		if len(calls) != 1 {
			t.Fatalf("assistant.tool_calls: want 1, got %d", len(calls))
		}
		// Use the comma-ok form at every level: an unchecked assertion here
		// would panic on an unexpected shape instead of failing the test
		// with a readable message.
		call, ok := calls[0].(map[string]any)
		if !ok {
			t.Fatalf("assistant.tool_calls[0]: want JSON object, got %T", calls[0])
		}
		fn, _ := call["function"].(map[string]any)
		if fn["name"] != "search" {
			t.Errorf("assistant.tool_calls[0].function.name: want search, got %v", fn["name"])
		}
		args, _ := fn["arguments"].(map[string]any)
		if args["query"] != "foo" {
			t.Errorf("assistant.tool_calls[0].function.arguments.query: want foo, got %v", args["query"])
		}

		// Tool-role message must have role=tool, tool_call_id, and content.
		tool, _ := msgs[2].(map[string]any)
		if tool["role"] != "tool" {
			t.Errorf("msgs[2].role: want tool, got %v", tool["role"])
		}
		if tool["tool_call_id"] != "tc1" {
			t.Errorf("msgs[2].tool_call_id: want tc1, got %v", tool["tool_call_id"])
		}
		if !strings.Contains(toString(tool["content"]), "bar") {
			t.Errorf("msgs[2].content: want to contain bar, got %v", tool["content"])
		}
	})
}
// toString renders v as text for substring checks in tests. A string value
// is returned unchanged; any other value is returned as its JSON encoding.
// Encoding errors are ignored, in which case the result is empty.
func toString(v any) string {
	switch val := v.(type) {
	case string:
		return val
	default:
		encoded, _ := json.Marshal(val)
		return string(encoded)
	}
}
// streamServer returns an httptest.Server that records each inbound request
// into captured and replies with status 200 and Content-Type
// application/x-ndjson, writing the given lines one at a time (each
// terminated with \n). When the ResponseWriter supports http.Flusher the
// response is flushed after every line. The server is shut down
// automatically through t.Cleanup.
func streamServer(t *testing.T, captured *captureRequest, lines []string) *httptest.Server {
	t.Helper()
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured.method = r.Method
		captured.path = r.URL.Path
		captured.authHeader = r.Header.Get("Authorization")
		captured.contentType = r.Header.Get("Content-Type")

		// Read and decode errors are deliberately ignored: a test that needs
		// the body will fail on its own assertions if it is missing.
		body, _ := io.ReadAll(r.Body)
		captured.body = body

		// Decode into a fresh map. json.Unmarshal reuses a non-nil map and
		// keeps its existing entries, which would leak keys from an earlier
		// request into the record of the current one.
		var parsed map[string]any
		_ = json.Unmarshal(body, &parsed)
		captured.parsedBody = parsed

		w.Header().Set("Content-Type", "application/x-ndjson")
		w.WriteHeader(200)

		flusher, canFlush := w.(http.Flusher)
		for _, line := range lines {
			_, _ = w.Write([]byte(line + "\n"))
			if canFlush {
				flusher.Flush()
			}
		}
	}))
	t.Cleanup(srv.Close)
	return srv
}
// collectStream runs p.Stream in a goroutine and gathers every event it
// emits on a buffered channel, returning them in arrival order.
//
// Collection ends when the events channel is closed or when Stream returns,
// whichever is seen first; in the latter case events still buffered in the
// channel are drained before returning. A non-nil error from Stream fails
// the test unless a StreamEventError was collected, in which case the error
// is treated as already reported through the event stream. The whole
// operation is bounded by a 5 second timeout.
func collectStream(t *testing.T, p *Provider, req provider.Request) []provider.StreamEvent {
	t.Helper()

	// Cancel the stream's context on every exit path (including t.Fatal,
	// which unwinds through deferred calls) so a stuck Stream goroutine is
	// given the chance to stop instead of outliving the test.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	events := make(chan provider.StreamEvent, 64)
	done := make(chan error, 1)
	go func() {
		done <- p.Stream(ctx, req, events)
	}()

	var out []provider.StreamEvent
	streamErrored := false
	record := func(ev provider.StreamEvent) {
		out = append(out, ev)
		if ev.Type == provider.StreamEventError {
			streamErrored = true
		}
	}

	timeout := time.After(5 * time.Second)
	for {
		select {
		case ev, ok := <-events:
			if ok {
				record(ev)
				continue
			}
			// Channel closed: wait for Stream's return value, but keep the
			// wait bounded by the same overall timeout.
			select {
			case err := <-done:
				if err != nil && !streamErrored {
					t.Fatalf("Stream returned error: %v", err)
				}
			case <-timeout:
				t.Fatal("Stream did not complete within 5s")
			}
			return out

		case err := <-done:
			// Stream has returned. Drain whatever is still buffered BEFORE
			// judging the error: select chooses at random among ready
			// cases, so an error event may still be sitting in the channel
			// when this branch wins.
		drain:
			for {
				select {
				case ev, ok := <-events:
					if !ok {
						break drain
					}
					record(ev)
				default:
					break drain
				}
			}
			if err != nil && !streamErrored {
				t.Fatalf("Stream returned error: %v", err)
			}
			return out

		case <-timeout:
			t.Fatal("Stream did not complete within 5s")
		}
	}
}
// TestStreamBasic streams three NDJSON chunks (text, text+thinking, final
// done chunk with usage counts) and checks the order and payload of the
// emitted text/thinking/done events as well as the aggregated Response
// attached to the done event.
func TestStreamBasic(t *testing.T) {
	ndjson := []string{
		`{"message":{"role":"assistant","content":"hello"},"done":false}`,
		`{"message":{"role":"assistant","content":" world","thinking":"reasoning"},"done":false}`,
		`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":12,"eval_count":2}`,
	}
	captured := &captureRequest{}
	server := streamServer(t, captured, ndjson)
	client := newNative("", server.URL)

	events := collectStream(t, client, provider.Request{
		Model:    "kimi-k2.5",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})

	// Verify request shape: stream:true.
	if sent := captured.parsedBody["stream"]; sent != true {
		t.Errorf("body.stream: want true, got %v", sent)
	}

	// Reduce the event list to the kinds under test, preserving order.
	var (
		kinds []string
		texts []string
		final *provider.StreamEvent
	)
	for i := range events {
		ev := &events[i]
		switch ev.Type {
		case provider.StreamEventText:
			kinds = append(kinds, "text")
			texts = append(texts, ev.Text)
		case provider.StreamEventThinking:
			kinds = append(kinds, "thinking")
			texts = append(texts, ev.Text)
		case provider.StreamEventDone:
			kinds = append(kinds, "done")
			final = ev
		}
	}

	wantKinds := []string{"text", "thinking", "text", "done"}
	if !equalStrings(kinds, wantKinds) {
		t.Errorf("event kinds: want %v, got %v", wantKinds, kinds)
	}
	// Payload checks only make sense when all three chunks were observed;
	// a shortfall is already reported by the kinds comparison above.
	if len(texts) >= 3 {
		if first := texts[0]; first != "hello" {
			t.Errorf("first text: want hello, got %q", first)
		}
		if thinking := texts[1]; thinking != "reasoning" {
			t.Errorf("thinking: want reasoning, got %q", thinking)
		}
		if second := texts[2]; second != " world" {
			t.Errorf("second text: want \" world\", got %q", second)
		}
	}

	if final == nil || final.Response == nil {
		t.Fatal("Done event missing Response")
	}
	result := final.Response
	if result.Text != "hello world" {
		t.Errorf("Response.Text: want %q, got %q", "hello world", result.Text)
	}
	if result.Thinking != "reasoning" {
		t.Errorf("Response.Thinking: want %q, got %q", "reasoning", result.Thinking)
	}
	usage := result.Usage
	if usage == nil {
		t.Fatal("Response.Usage missing")
	}
	if !(usage.InputTokens == 12 && usage.OutputTokens == 2) {
		t.Errorf("Usage: want input=12 output=2, got input=%d output=%d", usage.InputTokens, usage.OutputTokens)
	}
}
// TestStreamToolDeltaAccumulation streams a tool call whose JSON arguments
// arrive split across two chunks and checks that the provider emits
// start, delta, delta, end events, that the end event and the final Response
// carry the concatenated arguments, and that ID and name are preserved.
func TestStreamToolDeltaAccumulation(t *testing.T) {
	const fullArgs = `{"query":"foo"}`

	ndjson := []string{
		`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"name":"search","arguments":"{\"que"}}]},"done":false}`,
		`{"message":{"role":"assistant","content":"","tool_calls":[{"id":"tc1","function":{"arguments":"ry\":\"foo\"}"}}]},"done":false}`,
		`{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":4,"eval_count":1}`,
	}
	captured := &captureRequest{}
	server := streamServer(t, captured, ndjson)
	client := newNative("", server.URL)

	events := collectStream(t, client, provider.Request{
		Model:    "kimi-k2.5",
		Messages: []provider.Message{{Role: "user", Content: "search foo"}},
		Tools: []provider.ToolDef{
			{Name: "search", Schema: map[string]any{"type": "object"}},
		},
	})

	// Build a slim trace of tool events.
	type traceEntry struct {
		kind string
		args string
		name string
		id   string
	}
	var (
		trace []traceEntry
		final *provider.StreamEvent
	)
	for i := range events {
		ev := &events[i]
		switch ev.Type {
		case provider.StreamEventToolStart:
			trace = append(trace, traceEntry{kind: "start", name: ev.ToolCall.Name, id: ev.ToolCall.ID})
		case provider.StreamEventToolDelta:
			trace = append(trace, traceEntry{kind: "delta", args: ev.ToolCall.Arguments})
		case provider.StreamEventToolEnd:
			trace = append(trace, traceEntry{
				kind: "end",
				args: ev.ToolCall.Arguments,
				name: ev.ToolCall.Name,
				id:   ev.ToolCall.ID,
			})
		case provider.StreamEventDone:
			final = ev
		}
	}

	if n := len(trace); n != 4 {
		t.Fatalf("trace: want 4 entries (start, delta, delta, end), got %d: %+v", n, trace)
	}
	start, delta1, delta2, end := trace[0], trace[1], trace[2], trace[3]
	if !(start.kind == "start" && start.name == "search" && start.id == "tc1") {
		t.Errorf("trace[0]: want start search tc1, got %+v", start)
	}
	if !(delta1.kind == "delta" && delta1.args == `{"que`) {
		t.Errorf("trace[1]: want delta args=%q, got %+v", `{"que`, delta1)
	}
	if !(delta2.kind == "delta" && delta2.args == `ry":"foo"}`) {
		t.Errorf("trace[2]: want delta args=%q, got %+v", `ry":"foo"}`, delta2)
	}
	if !(end.kind == "end" && end.args == fullArgs) {
		t.Errorf("trace[3]: want end args=%q, got %+v", fullArgs, end)
	}

	if final == nil || final.Response == nil {
		t.Fatal("Done event missing Response")
	}
	calls := final.Response.ToolCalls
	if n := len(calls); n != 1 {
		t.Fatalf("Done.Response.ToolCalls: want 1, got %d", n)
	}
	call := calls[0]
	if !(call.ID == "tc1" && call.Name == "search" && call.Arguments == fullArgs) {
		t.Errorf("Done.Response.ToolCalls[0]: want tc1/search/{...}, got %+v", call)
	}
}
// equalStrings reports whether a and b have the same length and hold equal
// strings at every index. A nil slice and an empty slice compare equal.
func equalStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, left := range a {
		if right := b[idx]; left != right {
			return false
		}
	}
	return true
}
// --- Retry tests ---
// newTestNative builds a Provider through newNative and sets its retry base
// delay override to one millisecond so that retry tests run quickly.
func newTestNative(apiKey, baseURL string) *Provider {
	native := newNative(apiKey, baseURL)
	native.retryBaseDelayOverride = time.Millisecond
	return native
}
func TestRetryOnTransientHTTPError(t *testing.T) {
t.Run("503 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Drain the request body so the client can reuse the connection.
_, _ = io.ReadAll(r.Body)
attempts++
if attempts <= 2 {
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"model is temporarily overloaded"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"ok"},"done":true,"prompt_eval_count":1,"eval_count":1}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("key", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retries, got error: %v", err)
}
if resp.Text != "ok" {
t.Errorf("Text: want %q, got %q", "ok", resp.Text)
}
if attempts != 3 {
t.Errorf("attempts: want 3, got %d", attempts)
}
})
t.Run("429 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(429)
_, _ = w.Write([]byte(`{"error":"rate limited"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"done"},"done":true}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if resp.Text != "done" {
t.Errorf("Text: want %q, got %q", "done", resp.Text)
}
if attempts != 2 {
t.Errorf("attempts: want 2, got %d", attempts)
}
})
t.Run("502 retries then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(502)
_, _ = w.Write([]byte(`Bad Gateway`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"yes"},"done":true}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
resp, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if resp.Text != "yes" {
t.Errorf("Text: want %q, got %q", "yes", resp.Text)
}
})
t.Run("exhausts retries and returns error", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
_, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error after exhausting retries")
}
if !strings.Contains(err.Error(), "503") {
t.Errorf("error should mention status 503, got: %v", err)
}
// 1 initial + 3 retries = 4 total
if attempts != 4 {
t.Errorf("attempts: want 4, got %d", attempts)
}
})
t.Run("400 is not retried", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(400)
_, _ = w.Write([]byte(`{"error":"bad request"}`))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
_, err := p.Complete(context.Background(), provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error for 400")
}
if attempts != 1 {
t.Errorf("attempts: want 1 (no retries for 400), got %d", attempts)
}
})
t.Run("context cancellation during backoff aborts retry", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
}))
t.Cleanup(srv.Close)
ctx, cancel := context.WithCancel(context.Background())
// Cancel shortly after the first attempt so the backoff wait is interrupted.
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
p := newTestNative("", srv.URL)
// Use a longer backoff so the context cancel fires during the wait.
p.retryBaseDelayOverride = 2 * time.Second
_, err := p.Complete(ctx, provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
if err == nil {
t.Fatal("expected error from cancelled context")
}
if !strings.Contains(err.Error(), "canceled") && !strings.Contains(err.Error(), "context") {
t.Errorf("expected context error, got: %v", err)
}
// Should have made only 1 attempt before the context cancelled during backoff.
if attempts != 1 {
t.Errorf("attempts: want 1, got %d", attempts)
}
})
t.Run("stream retries on 503 then succeeds", func(t *testing.T) {
var attempts int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.ReadAll(r.Body)
attempts++
if attempts == 1 {
w.WriteHeader(503)
_, _ = w.Write([]byte(`{"error":"overloaded"}`))
return
}
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"streamed"},"done":true,"prompt_eval_count":1,"eval_count":1}` + "\n"))
}))
t.Cleanup(srv.Close)
p := newTestNative("", srv.URL)
events := collectStream(t, p, provider.Request{
Model: "test",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
var gotDone bool
for _, ev := range events {
if ev.Type == provider.StreamEventDone && ev.Response != nil {
if ev.Response.Text != "streamed" {
t.Errorf("Response.Text: want %q, got %q", "streamed", ev.Response.Text)
}
gotDone = true
}
}
if !gotDone {
t.Fatal("expected StreamEventDone")
}
if attempts != 2 {
t.Errorf("attempts: want 2, got %d", attempts)
}
})
}
// TestRetryBackoff pins the delays returned by retryBackoff: doubling from
// one second when no override is given, a numeric Retry-After header taking
// precedence, fallback to the exponential schedule when Retry-After is not a
// number, and a non-zero base override replacing the one-second base.
func TestRetryBackoff(t *testing.T) {
	t.Run("exponential backoff without Retry-After", func(t *testing.T) {
		noHeaders := http.Header{}
		if got := retryBackoff(0, noHeaders, 0); got != 1*time.Second {
			t.Errorf("attempt 0: want 1s, got %v", got)
		}
		if got := retryBackoff(1, noHeaders, 0); got != 2*time.Second {
			t.Errorf("attempt 1: want 2s, got %v", got)
		}
		if got := retryBackoff(2, noHeaders, 0); got != 4*time.Second {
			t.Errorf("attempt 2: want 4s, got %v", got)
		}
	})

	t.Run("Retry-After header overrides backoff", func(t *testing.T) {
		headers := http.Header{}
		headers.Set("Retry-After", "5")
		if got := retryBackoff(0, headers, 0); got != 5*time.Second {
			t.Errorf("want 5s from Retry-After, got %v", got)
		}
	})

	t.Run("invalid Retry-After falls back to exponential", func(t *testing.T) {
		headers := http.Header{}
		headers.Set("Retry-After", "not-a-number")
		if got := retryBackoff(1, headers, 0); got != 2*time.Second {
			t.Errorf("want 2s fallback, got %v", got)
		}
	})

	t.Run("baseOverride replaces default base delay", func(t *testing.T) {
		const base = 500 * time.Millisecond
		noHeaders := http.Header{}
		if got := retryBackoff(0, noHeaders, base); got != 500*time.Millisecond {
			t.Errorf("attempt 0 with 500ms override: want 500ms, got %v", got)
		}
		if got := retryBackoff(2, noHeaders, base); got != 2*time.Second {
			t.Errorf("attempt 2 with 500ms override: want 2s, got %v", got)
		}
	})
}
// TestIsTransientHTTPStatus checks that isTransientHTTPStatus reports true
// only for 429, 502 and 503 among the listed status codes.
func TestIsTransientHTTPStatus(t *testing.T) {
	type statusCase struct {
		code int
		want bool
	}
	table := []statusCase{
		{code: 200, want: false},
		{code: 400, want: false},
		{code: 401, want: false},
		{code: 403, want: false},
		{code: 404, want: false},
		{code: 429, want: true},
		{code: 500, want: false},
		{code: 502, want: true},
		{code: 503, want: true},
	}
	for _, tc := range table {
		got := isTransientHTTPStatus(tc.code)
		if got == tc.want {
			continue
		}
		t.Errorf("isTransientHTTPStatus(%d): want %v, got %v", tc.code, tc.want, got)
	}
}
// TestTruncateBody checks that truncateBody returns input shorter than the
// limit unchanged and cuts longer input to the limit followed by "...".
func TestTruncateBody(t *testing.T) {
	const (
		short = "hello"
		long  = "hello world this is a very long string"
	)

	if out := truncateBody([]byte(short), 10); out != short {
		t.Errorf("short: want %q, got %q", short, out)
	}

	if out := truncateBody([]byte(long), 10); out != "hello worl..." {
		t.Errorf("long: want %q, got %q", "hello worl...", out)
	}
}