Files
llama-swap/internal/router/loading_test.go
T
Benson Wong 62aea0e83d internal/router,server,shared: refactor auth, libs (#839)
- refactor shared http functionality into internal/shared/http.go
- remove stripping of Authorization and x-api-key
- add Request Context middleware to internal/server
- add /ui and /metrics behind auth middleware, fixes #717

Fix #717
Updates: #834
2026-06-13 10:19:04 -07:00

266 lines
6.9 KiB
Go

package router
import (
"bufio"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
}
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
t.Errorf("Cache-Control: want no-cache, got %q", cc)
}
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
t.Errorf("Connection: want keep-alive, got %q", conn)
}
body := w.Body.String()
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected SSE data: prefix, got: %s", body)
}
content := extractStreamedContent(body)
if !strings.Contains(content, "━━━━━\n") {
t.Errorf("missing separator in streamed content: %q", content)
}
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
t.Errorf("missing initial message in streamed content: %q", content)
}
}
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.WriteHeader(http.StatusCreated)
if w.Code != http.StatusOK {
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
}
}
func TestLoadingWriter_WritePassthrough(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.Write([]byte("hello"))
lw.Flush()
body := w.Body.String()
if !strings.Contains(body, "hello") {
t.Errorf("Write passthrough failed, body: %s", body)
}
}
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
if !strings.Contains(body, "Done!") {
t.Errorf("expected Done! message, body: %s", body)
}
}
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
lw.setUpdate("custom status message")
time.Sleep(50 * time.Millisecond)
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
content := extractStreamedContent(body)
if !strings.Contains(content, "custom status message") {
t.Errorf("expected setUpdate message in output, got: %q", content)
}
}
func TestLoadingWriter_SendDataFormat(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.sendData("hello world")
body := w.Body.String()
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
}
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected data: prefix, got: %s", body)
}
}
func TestLoadingWriter_SendLine(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.charPerSecond = 0
// Capture only the content from this sendLine call
before := w.Body.Len()
lw.sendLine("line content")
after := w.Body.Len()
chunkBody := w.Body.String()[before:after]
content := extractStreamedContent(chunkBody)
if content != "line content\n" {
t.Errorf("expected complete streamed line, got: %q", content)
}
}
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
lw.start(ctx)
close(done)
}()
<-lw.loopStarted
time.Sleep(50 * time.Millisecond)
cancel()
<-done
body := w.Body.String()
lines := countSSEMessages(body)
if lines < 2 {
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
}
}
func TestLoadingWriter_ReqStored(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if lw.req != req {
t.Fatal("req not stored")
}
}
func TestIsLoadingPath(t *testing.T) {
tests := []struct {
path string
want bool
}{
{"/v1/chat/completions", true},
{"/v1/chat/completions/extra", true},
{"/v1/completions", false},
{"/v1/embeddings", false},
{"/health", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := isLoadingPath(tt.path); got != tt.want {
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func countSSEMessages(s string) int {
scanner := bufio.NewScanner(strings.NewReader(s))
count := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
count++
}
}
return count
}
func extractStreamedContent(body string) string {
var result strings.Builder
scanner := bufio.NewScanner(strings.NewReader(body))
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonData := strings.TrimPrefix(line, "data: ")
var msg struct {
Choices []struct {
Delta struct {
ReasoningContent string `json:"reasoning_content"`
} `json:"delta"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
continue
}
if len(msg.Choices) > 0 {
result.WriteString(msg.Choices[0].Delta.ReasoningContent)
}
}
return result.String()
}