package llm import ( "context" "errors" "fmt" "net" "strings" "syscall" "testing" ) type fakeNetErr struct{ timeout bool } func (e fakeNetErr) Error() string { return "fake net error" } func (e fakeNetErr) Timeout() bool { return e.timeout } func (e fakeNetErr) Temporary() bool { return true } var _ net.Error = fakeNetErr{} func TestClassify(t *testing.T) { tests := []struct { name string err error want ErrorClass }{ {"canceled is permanent", context.Canceled, ClassPermanent}, {"deadline is transient", context.DeadlineExceeded, ClassTransient}, {"wrapped canceled", fmt.Errorf("call: %w", context.Canceled), ClassPermanent}, {"model not found", fmt.Errorf("x: %w", ErrModelNotFound), ClassPermanent}, {"conn refused", syscall.ECONNREFUSED, ClassTransient}, {"conn reset", fmt.Errorf("write: %w", syscall.ECONNRESET), ClassTransient}, {"net timeout", fakeNetErr{timeout: true}, ClassTransient}, {"http 429", &APIError{Status: 429}, ClassTransient}, {"http 408", &APIError{Status: 408}, ClassTransient}, {"http 500", &APIError{Status: 500}, ClassTransient}, {"http 503", &APIError{Status: 503}, ClassTransient}, {"http 529", &APIError{Status: 529}, ClassTransient}, {"http 400", &APIError{Status: 400}, ClassPermanent}, {"http 401", &APIError{Status: 401}, ClassPermanent}, {"http 403", &APIError{Status: 403}, ClassPermanent}, {"http 404", &APIError{Status: 404}, ClassPermanent}, {"http 422", &APIError{Status: 422}, ClassPermanent}, {"wrapped api error", fmt.Errorf("call: %w", &APIError{Status: 503}), ClassTransient}, {"unknown defaults transient", errors.New("mystery"), ClassTransient}, {"non-http api error defaults transient", &APIError{Message: "decode failed"}, ClassTransient}, } for _, tt := range tests { if got := Classify(tt.err); got != tt.want { t.Errorf("%s: Classify = %v, want %v", tt.name, got, tt.want) } } } func TestAPIError404UnwrapsToModelNotFound(t *testing.T) { err := &APIError{Provider: "openai", Model: "nope", Status: 404} if !errors.Is(err, ErrModelNotFound) { t.Error("404 APIError should match ErrModelNotFound") } if errors.Is(&APIError{Status: 500}, ErrModelNotFound) { t.Error("500 APIError must not match ErrModelNotFound") } } func TestAPIErrorMessage(t *testing.T) { err := &APIError{ Provider: "anthropic", Model: "opus-4.8", Status: 429, Code: "rate_limit_error", Message: "slow down", } got := err.Error() for _, frag := range []string{"anthropic/opus-4.8", "429", "rate_limit_error", "slow down"} { if !strings.Contains(got, frag) { t.Errorf("error string %q missing %q", got, frag) } } } func TestAPIErrorUnwrapsCause(t *testing.T) { cause := errors.New("boom") err := &APIError{Provider: "p", Model: "m", Err: cause} if !errors.Is(err, cause) { t.Error("APIError should unwrap to its cause") } }