package llamaswap

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
	"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)

// onePixelPNG is a 1x1 transparent PNG, base64-encoded (used to assert image
// decoding end-to-end).
const onePixelPNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="

// TestChatDelegatesToOpenAI checks that a chat Generate call reaches
// /v1/chat/completions with the configured bearer token and that the request
// body carries the legacy max_tokens field rather than max_completion_tokens.
func TestChatDelegatesToOpenAI(t *testing.T) {
	var gotPath, gotAuth string
	var gotBody map[string]any
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotAuth = r.Header.Get("Authorization")
		// Report a bad body directly; otherwise it only shows up later as a
		// misleading "missing max_tokens" failure.
		if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
			t.Errorf("decoding request body: %v", err)
		}
		w.Header().Set("Content-Type", "application/json")
		// Best-effort write: a failure surfaces as a client-side error below.
		_, _ = w.Write([]byte(`{"choices":[{"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}]}`))
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithToken("test-token"), WithHTTPClient(srv.Client()))
	m, err := p.Model("qwen3:14b")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	resp, err := m.Generate(context.Background(), llm.Request{
		Messages:  []llm.Message{llm.UserText("hello")},
		MaxTokens: 64,
	})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if resp.Text() != "hi" {
		t.Errorf("Text = %q, want %q", resp.Text(), "hi")
	}
	if gotPath != "/v1/chat/completions" {
		t.Errorf("path = %q, want /v1/chat/completions", gotPath)
	}
	if gotAuth != "Bearer test-token" {
		t.Errorf("auth = %q, want Bearer test-token", gotAuth)
	}
	// llama.cpp's OpenAI shim wants the legacy max_tokens field.
	if _, ok := gotBody["max_tokens"]; !ok {
		t.Errorf("request missing max_tokens (legacy); body=%v", gotBody)
	}
	if _, ok := gotBody["max_completion_tokens"]; ok {
		t.Errorf("request used max_completion_tokens; want legacy max_tokens")
	}
}

// TestChatNoTokenSendsPlaceholder checks that a provider built without a
// token still sends an Authorization header, using the "no-key" placeholder.
func TestChatNoTokenSendsPlaceholder(t *testing.T) {
	var gotAuth string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotAuth = r.Header.Get("Authorization")
		// Best-effort write: a failure surfaces as a client-side error below.
		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client())) // no token
	// Fail here with a clear message rather than dereferencing an unusable
	// model if construction fails.
	m, err := p.Model("m")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("x")}}); err != nil {
		t.Fatalf("Generate: %v", err)
	}
	// Keyless local llama-swap: a placeholder bearer it ignores, never a blank
	// that the openai client would reject as a synthetic 401.
	if gotAuth != "Bearer no-key" {
		t.Errorf("auth = %q, want Bearer no-key", gotAuth)
	}
}

// TestModelNoBaseURL checks that both Model and ImageModel return an error
// when the provider has no base URL configured.
func TestModelNoBaseURL(t *testing.T) {
	if _, err := New().Model("m"); err == nil {
		t.Fatal("expected error for missing base URL")
	}
	if _, err := New().ImageModel("m"); err == nil {
		t.Fatal("expected error for missing base URL (image)")
	}
}

// TestListModels checks that ListModels queries /v1/models and returns the
// listed entries in order.
func TestListModels(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/models" {
			t.Errorf("path = %q", r.URL.Path)
		}
		// Best-effort write: a failure surfaces as a client-side error below.
		_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"qwen3:14b","object":"model","owned_by":"llama-swap"},{"id":"sd","object":"model"}]}`))
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
	models, err := p.ListModels(context.Background())
	if err != nil {
		t.Fatalf("ListModels: %v", err)
	}
	if len(models) != 2 || models[0].ID != "qwen3:14b" {
		t.Fatalf("models = %+v", models)
	}
}

// TestUnload checks the method and path used to unload one model and all
// models, and that a model id containing a path separator is rejected
// without any request being sent.
func TestUnload(t *testing.T) {
	var gotPath, gotMethod string
	var requests int // number of requests the stub server has received
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requests++
		gotPath, gotMethod = r.URL.Path, r.Method
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
	if err := p.Unload(context.Background(), "qwen3:14b"); err != nil {
		t.Fatalf("Unload: %v", err)
	}
	if gotMethod != http.MethodPost || gotPath != "/api/models/unload/qwen3:14b" {
		t.Errorf("got %s %s", gotMethod, gotPath)
	}
	if err := p.Unload(context.Background(), ""); err != nil {
		t.Fatalf("Unload all: %v", err)
	}
	if gotPath != "/api/models/unload" {
		t.Errorf("unload-all path = %q", gotPath)
	}
	// A model id with a path separator is rejected before any request.
	before := requests
	if err := p.Unload(context.Background(), "../admin"); err == nil {
		t.Error("expected error for model id with path separator")
	}
	if requests != before {
		t.Errorf("model id with path separator sent %d request(s), want 0", requests-before)
	}
}

// TestManagementNoBaseURL checks that ListModels, Running and Unload each
// return an error when the provider has no base URL configured.
func TestManagementNoBaseURL(t *testing.T) {
	p := New() // no base URL
	if _, err := p.ListModels(context.Background()); err == nil {
		t.Error("ListModels: expected error for missing base URL")
	}
	if _, err := p.Running(context.Background()); err == nil {
		t.Error("Running: expected error for missing base URL")
	}
	if err := p.Unload(context.Background(), "m"); err == nil {
		t.Error("Unload: expected error for missing base URL")
	}
}

// TestRunningRaw checks that Running returns the server's response body
// unchanged.
func TestRunningRaw(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Best-effort write: a failure surfaces as a client-side error below.
		_, _ = w.Write([]byte(`{"running":["qwen3:14b"]}`))
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
	raw, err := p.Running(context.Background())
	if err != nil {
		t.Fatalf("Running: %v", err)
	}
	if string(raw) != `{"running":["qwen3:14b"]}` {
		t.Errorf("raw = %s", raw)
	}
}

// TestImageGenerate checks that image generation posts to
// /v1/images/generations, requests b64_json output, forwards the size option,
// and decodes the returned image as a non-empty PNG.
func TestImageGenerate(t *testing.T) {
	var gotBody map[string]any
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/images/generations" {
			t.Errorf("path = %q", r.URL.Path)
		}
		// Report a bad body directly; otherwise it only shows up later as a
		// misleading response_format/size mismatch.
		if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
			t.Errorf("decoding request body: %v", err)
		}
		// Best-effort write: a failure surfaces as a client-side error below.
		_, _ = w.Write([]byte(`{"created":1,"data":[{"b64_json":"` + onePixelPNG + `"}]}`))
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
	im, err := p.ImageModel("sd")
	if err != nil {
		t.Fatalf("ImageModel: %v", err)
	}
	res, err := im.Generate(context.Background(), imagegen.Request{Prompt: "a red bicycle"}, imagegen.WithSize("512x512"))
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if len(res.Images) != 1 {
		t.Fatalf("images = %d, want 1", len(res.Images))
	}
	if res.Images[0].MIME != "image/png" {
		t.Errorf("MIME = %q, want image/png", res.Images[0].MIME)
	}
	if len(res.Images[0].Data) == 0 {
		t.Error("decoded image has no bytes")
	}
	// response_format must be forced to b64_json, and options applied.
	if gotBody["response_format"] != "b64_json" {
		t.Errorf("response_format = %v, want b64_json", gotBody["response_format"])
	}
	if gotBody["size"] != "512x512" {
		t.Errorf("size = %v, want 512x512", gotBody["size"])
	}
}

// TestImageGenerateEmptyPrompt checks that a whitespace-only prompt yields an
// error matching llm.ErrUnsupported.
func TestImageGenerateEmptyPrompt(t *testing.T) {
	p := New(WithBaseURL("http://example.invalid"))
	// Fail here with a clear message rather than dereferencing an unusable
	// model if construction fails.
	im, err := p.ImageModel("sd")
	if err != nil {
		t.Fatalf("ImageModel: %v", err)
	}
	_, err = im.Generate(context.Background(), imagegen.Request{Prompt: " "})
	if !errors.Is(err, llm.ErrUnsupported) {
		t.Errorf("err = %v, want ErrUnsupported", err)
	}
}

// TestAPIErrorClassifies checks that a 429 response with an error payload is
// returned as an *llm.APIError carrying the status and code, and that
// llm.Classify reports it as transient.
func TestAPIErrorClassifies(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
		// Best-effort write: a failure surfaces as a client-side error below.
		_, _ = w.Write([]byte(`{"error":{"message":"slow down","code":"rate_limited"}}`))
	}))
	defer srv.Close()

	p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
	_, err := p.ListModels(context.Background())
	if err == nil {
		t.Fatal("expected error")
	}
	var apiErr *llm.APIError
	if !errors.As(err, &apiErr) {
		t.Fatalf("err type = %T, want *llm.APIError", err)
	}
	if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limited" {
		t.Errorf("apiErr = %+v", apiErr)
	}
	if llm.Classify(err) != llm.ClassTransient {
		t.Errorf("429 should classify transient")
	}
}