package openai import ( "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "testing" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" ) // sseServer streams each payload as one "data: " SSE event and // records the request like newServer. func sseServer(t *testing.T, payloads ...string) (*httptest.Server, *recorded) { t.Helper() rec := &recorded{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rec.hits++ rec.header = r.Header.Clone() rec.path = r.URL.Path raw, err := io.ReadAll(r.Body) if err != nil { t.Errorf("read request body: %v", err) } if len(raw) > 0 { if err := json.Unmarshal(raw, &rec.body); err != nil { t.Errorf("request body is not JSON: %v\n%s", err, raw) } } w.Header().Set("Content-Type", "text/event-stream") for _, p := range payloads { io.WriteString(w, "data: "+p+"\n\n") } })) t.Cleanup(srv.Close) return srv, rec } // collect drains a stream to io.EOF, failing the test on any other error. func collect(t *testing.T, s llm.Stream) []llm.StreamEvent { t.Helper() var events []llm.StreamEvent for { ev, err := s.Next() if err == io.EOF { return events } if err != nil { t.Fatalf("Next: %v", err) } events = append(events, ev) } } func TestStreamText(t *testing.T) { srv, rec := sseServer(t, `{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-test","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}],"obfuscation":"xK9q"}`, `{"choices":[{"index":0,"delta":{"content":"lo"},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, `{"choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, `[DONE]`, ) m := testModel(t, srv, nil) s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Stream: %v", err) } defer s.Close() events := collect(t, s) // Request shape: stream flag, usage opt-in, SSE accept header. if rec.body["stream"] != true { t.Errorf("stream = %v, want true", rec.body["stream"]) } so, _ := rec.body["stream_options"].(map[string]any) if so == nil || so["include_usage"] != true { t.Errorf("stream_options = %v, want include_usage true", rec.body["stream_options"]) } if got := rec.header.Get("Accept"); got != "text/event-stream" { t.Errorf("Accept = %q, want text/event-stream", got) } if len(events) != 3 { t.Fatalf("got %d events, want 3: %+v", len(events), events) } if events[0].TextDelta != "Hel" || events[1].TextDelta != "lo" { t.Errorf("deltas = %q, %q, want Hel, lo", events[0].TextDelta, events[1].TextDelta) } final := events[2].Response if final == nil { t.Fatal("last event has no Response") } if got := final.Text(); got != "Hello" { t.Errorf("final text = %q, want Hello", got) } if final.FinishReason != llm.FinishStop { t.Errorf("FinishReason = %v, want stop", final.FinishReason) } if final.Usage != (llm.Usage{InputTokens: 5, OutputTokens: 2}) { t.Errorf("Usage = %+v, want {5 2}", final.Usage) } if final.Model != "openai/gpt-test" { t.Errorf("Model = %q, want openai/gpt-test", final.Model) } // Next after EOF keeps returning EOF; Close is idempotent. if _, err := s.Next(); err != io.EOF { t.Errorf("Next after EOF = %v, want io.EOF", err) } if err := s.Close(); err != nil { t.Errorf("first Close: %v", err) } if err := s.Close(); err != nil { t.Errorf("second Close: %v", err) } } func TestStreamParallelToolCalls(t *testing.T) { // Two interleaved calls with distinct indexes; id/name only on the first // fragment of each; arguments split across fragments. srv, _ := sseServer(t, `{"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_a","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_b","type":"function","function":{"name":"get_time","arguments":"{\"tz\":"}}]},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Boston\"}"}}]},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"\"EST\"}"}}]},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, `{"choices":[],"usage":{"prompt_tokens":11,"completion_tokens":9,"total_tokens":20}}`, `[DONE]`, ) m := testModel(t, srv, nil) s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Stream: %v", err) } defer s.Close() events := collect(t, s) if len(events) != 3 { t.Fatalf("got %d events, want 3 (two tool calls + response): %+v", len(events), events) } a, b := events[0].ToolCall, events[1].ToolCall if a == nil || b == nil { t.Fatalf("events 0/1 are not tool calls: %+v", events) } if a.ID != "call_a" || a.Name != "get_weather" || string(a.Arguments) != `{"city":"Boston"}` { t.Errorf("first call = %+v", a) } if b.ID != "call_b" || b.Name != "get_time" || string(b.Arguments) != `{"tz":"EST"}` { t.Errorf("second call = %+v", b) } final := events[2].Response if final == nil { t.Fatal("last event has no Response") } if len(final.ToolCalls) != 2 { t.Fatalf("final ToolCalls = %d, want 2", len(final.ToolCalls)) } if final.ToolCalls[0].ID != "call_a" || final.ToolCalls[1].ID != "call_b" { t.Errorf("final ToolCalls order = %q, %q", final.ToolCalls[0].ID, final.ToolCalls[1].ID) } if final.FinishReason != llm.FinishToolCalls { t.Errorf("FinishReason = %v, want tool_calls", final.FinishReason) } if final.Usage != (llm.Usage{InputTokens: 11, OutputTokens: 9}) { t.Errorf("Usage = %+v, want {11 9}", final.Usage) } } func TestStreamMidStreamError(t *testing.T) { srv, _ := sseServer(t, `{"choices":[{"index":0,"delta":{"content":"par"},"finish_reason":null}]}`, `{"error":{"message":"The server had an error while processing your request","type":"server_error","param":null,"code":null}}`, ) m := testModel(t, srv, nil) s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Stream: %v", err) } defer s.Close() ev, err := s.Next() if err != nil || ev.TextDelta != "par" { t.Fatalf("first event = %+v, %v; want TextDelta par", ev, err) } _, err = s.Next() apiErr, ok := errors.AsType[*llm.APIError](err) if !ok { t.Fatalf("err = %v (%T), want *llm.APIError", err, err) } if apiErr.Code != "server_error" { t.Errorf("Code = %q, want server_error", apiErr.Code) } if apiErr.Message != "The server had an error while processing your request" { t.Errorf("Message = %q", apiErr.Message) } if apiErr.Status != 0 { t.Errorf("Status = %d, want 0 (the HTTP stream was 200)", apiErr.Status) } } func TestStreamHTTPError(t *testing.T) { srv, _ := newServer(t, http.StatusTooManyRequests, `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`) m := testModel(t, srv, nil) _, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) apiErr, ok := errors.AsType[*llm.APIError](err) if !ok { t.Fatalf("err = %v (%T), want *llm.APIError from Stream itself", err, err) } if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limit_exceeded" { t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code) } if got := llm.Classify(err); got != llm.ClassTransient { t.Errorf("Classify = %v, want transient", got) } } func TestStreamWithoutDoneSentinel(t *testing.T) { // Why: some compat servers close the connection without "data: [DONE]"; // a clean EOF must still produce the final Response. srv, _ := sseServer(t, `{"choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, ) m := testModel(t, srv, nil) s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Stream: %v", err) } defer s.Close() events := collect(t, s) if len(events) != 2 { t.Fatalf("got %d events, want 2: %+v", len(events), events) } final := events[1].Response if final == nil || final.Text() != "ok" || final.FinishReason != llm.FinishStop { t.Errorf("final = %+v", final) } } func TestStreamCloseEarly(t *testing.T) { srv, _ := sseServer(t, `{"choices":[{"index":0,"delta":{"content":"a"},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{"content":"b"},"finish_reason":null}]}`, `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, `[DONE]`, ) m := testModel(t, srv, nil) s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Stream: %v", err) } if _, err := s.Next(); err != nil { t.Fatalf("Next: %v", err) } if err := s.Close(); err != nil { t.Errorf("Close mid-stream: %v", err) } if err := s.Close(); err != nil { t.Errorf("Close again: %v", err) } }