a213c18263
The OpenAI /v1/images/generations endpoint ignores `seed` on our stable-diffusion.cpp build — every render of a given prompt comes back byte-identical, so a drawbot batch of N collapsed to one image. Switch the image provider to sd-server's A1111 /sdapi/v1/txt2img endpoint, which honors `seed` (verified live: distinct seeds -> distinct images on SDXL and Qwen-Image). Size is split into width/height; llama-swap still routes by the `model` field. Tests + ADR-0016 updated.
275 lines
8.9 KiB
Go
275 lines
8.9 KiB
Go
package llamaswap
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
|
)
|
|
|
|
// 1x1 transparent PNG, base64 (used to assert image decoding end-to-end).
|
|
// onePixelPNG is spliced verbatim into the `{"images":[...]}` JSON bodies that
// the stub servers in the image-generation tests of this file write back.
const onePixelPNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
|
|
|
func TestChatDelegatesToOpenAI(t *testing.T) {
|
|
var gotPath, gotAuth string
|
|
var gotBody map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPath = r.URL.Path
|
|
gotAuth = r.Header.Get("Authorization")
|
|
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"choices":[{"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithToken("test-token"), WithHTTPClient(srv.Client()))
|
|
m, err := p.Model("qwen3:14b")
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
resp, err := m.Generate(context.Background(), llm.Request{
|
|
Messages: []llm.Message{llm.UserText("hello")},
|
|
MaxTokens: 64,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if resp.Text() != "hi" {
|
|
t.Errorf("Text = %q, want %q", resp.Text(), "hi")
|
|
}
|
|
if gotPath != "/v1/chat/completions" {
|
|
t.Errorf("path = %q, want /v1/chat/completions", gotPath)
|
|
}
|
|
if gotAuth != "Bearer test-token" {
|
|
t.Errorf("auth = %q, want Bearer test-token", gotAuth)
|
|
}
|
|
// llama.cpp's OpenAI shim wants the legacy max_tokens field.
|
|
if _, ok := gotBody["max_tokens"]; !ok {
|
|
t.Errorf("request missing max_tokens (legacy); body=%v", gotBody)
|
|
}
|
|
if _, ok := gotBody["max_completion_tokens"]; ok {
|
|
t.Errorf("request used max_completion_tokens; want legacy max_tokens")
|
|
}
|
|
}
|
|
|
|
func TestChatNoTokenSendsPlaceholder(t *testing.T) {
|
|
var gotAuth string
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotAuth = r.Header.Get("Authorization")
|
|
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client())) // no token
|
|
m, _ := p.Model("m")
|
|
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("x")}}); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
// Keyless local llama-swap: a placeholder bearer it ignores, never a blank
|
|
// that the openai client would reject as a synthetic 401.
|
|
if gotAuth != "Bearer no-key" {
|
|
t.Errorf("auth = %q, want Bearer no-key", gotAuth)
|
|
}
|
|
}
|
|
|
|
func TestModelNoBaseURL(t *testing.T) {
|
|
if _, err := New().Model("m"); err == nil {
|
|
t.Fatal("expected error for missing base URL")
|
|
}
|
|
if _, err := New().ImageModel("m"); err == nil {
|
|
t.Fatal("expected error for missing base URL (image)")
|
|
}
|
|
}
|
|
|
|
func TestListModels(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/models" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"qwen3:14b","object":"model","owned_by":"llama-swap"},{"id":"sd","object":"model"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
|
models, err := p.ListModels(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListModels: %v", err)
|
|
}
|
|
if len(models) != 2 || models[0].ID != "qwen3:14b" {
|
|
t.Fatalf("models = %+v", models)
|
|
}
|
|
}
|
|
|
|
func TestUnload(t *testing.T) {
|
|
var gotPath, gotMethod string
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPath, gotMethod = r.URL.Path, r.Method
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
|
if err := p.Unload(context.Background(), "qwen3:14b"); err != nil {
|
|
t.Fatalf("Unload: %v", err)
|
|
}
|
|
if gotMethod != http.MethodPost || gotPath != "/api/models/unload/qwen3:14b" {
|
|
t.Errorf("got %s %s", gotMethod, gotPath)
|
|
}
|
|
|
|
if err := p.Unload(context.Background(), ""); err != nil {
|
|
t.Fatalf("Unload all: %v", err)
|
|
}
|
|
if gotPath != "/api/models/unload" {
|
|
t.Errorf("unload-all path = %q", gotPath)
|
|
}
|
|
|
|
// A model id with a path separator is rejected before any request.
|
|
if err := p.Unload(context.Background(), "../admin"); err == nil {
|
|
t.Error("expected error for model id with path separator")
|
|
}
|
|
}
|
|
|
|
func TestManagementNoBaseURL(t *testing.T) {
|
|
p := New() // no base URL
|
|
if _, err := p.ListModels(context.Background()); err == nil {
|
|
t.Error("ListModels: expected error for missing base URL")
|
|
}
|
|
if _, err := p.Running(context.Background()); err == nil {
|
|
t.Error("Running: expected error for missing base URL")
|
|
}
|
|
if err := p.Unload(context.Background(), "m"); err == nil {
|
|
t.Error("Unload: expected error for missing base URL")
|
|
}
|
|
}
|
|
|
|
func TestRunningRaw(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"running":["qwen3:14b"]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
|
raw, err := p.Running(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Running: %v", err)
|
|
}
|
|
if string(raw) != `{"running":["qwen3:14b"]}` {
|
|
t.Errorf("raw = %s", raw)
|
|
}
|
|
}
|
|
|
|
func TestImageGenerate(t *testing.T) {
|
|
var gotBody map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/sdapi/v1/txt2img" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
|
_, _ = w.Write([]byte(`{"images":["` + onePixelPNG + `"]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
|
im, err := p.ImageModel("sd")
|
|
if err != nil {
|
|
t.Fatalf("ImageModel: %v", err)
|
|
}
|
|
res, err := im.Generate(context.Background(), imagegen.Request{Prompt: "a red bicycle"}, imagegen.WithSize("512x512"))
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if len(res.Images) != 1 {
|
|
t.Fatalf("images = %d, want 1", len(res.Images))
|
|
}
|
|
if res.Images[0].MIME != "image/png" {
|
|
t.Errorf("MIME = %q, want image/png", res.Images[0].MIME)
|
|
}
|
|
if len(res.Images[0].Data) == 0 {
|
|
t.Error("decoded image has no bytes")
|
|
}
|
|
// Size is split into width/height ints for the A1111 endpoint.
|
|
if gotBody["width"] != float64(512) || gotBody["height"] != float64(512) {
|
|
t.Errorf("width/height = %v/%v, want 512/512", gotBody["width"], gotBody["height"])
|
|
}
|
|
}
|
|
|
|
func TestImageGenerateSettings(t *testing.T) {
|
|
var gotBody map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
|
_, _ = w.Write([]byte(`{"images":["` + onePixelPNG + `"]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
|
im, _ := p.ImageModel("sd")
|
|
|
|
// Unset overrides must be omitted entirely so sd-server keeps its own
|
|
// per-model defaults.
|
|
if _, err := im.Generate(context.Background(), imagegen.Request{Prompt: "x"}); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
for _, k := range []string{"steps", "cfg_scale", "negative_prompt", "sample_method", "seed"} {
|
|
if v, ok := gotBody[k]; ok {
|
|
t.Errorf("unset request sent %q = %v, want omitted", k, v)
|
|
}
|
|
}
|
|
|
|
// Set overrides are forwarded with the sd-server-friendly field names.
|
|
gotBody = nil
|
|
_, err := im.Generate(context.Background(), imagegen.Request{Prompt: "x"},
|
|
imagegen.WithSteps(8),
|
|
imagegen.WithCFGScale(3.5),
|
|
imagegen.WithNegativePrompt("blurry"),
|
|
imagegen.WithSampler("euler"),
|
|
imagegen.WithSeed(42),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
want := map[string]any{"steps": float64(8), "cfg_scale": 3.5, "negative_prompt": "blurry", "sample_method": "euler", "seed": float64(42)}
|
|
for k, w := range want {
|
|
if gotBody[k] != w {
|
|
t.Errorf("%s = %v, want %v", k, gotBody[k], w)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestImageGenerateEmptyPrompt(t *testing.T) {
|
|
p := New(WithBaseURL("http://example.invalid"))
|
|
im, _ := p.ImageModel("sd")
|
|
_, err := im.Generate(context.Background(), imagegen.Request{Prompt: " "})
|
|
if !errors.Is(err, llm.ErrUnsupported) {
|
|
t.Errorf("err = %v, want ErrUnsupported", err)
|
|
}
|
|
}
|
|
|
|
func TestAPIErrorClassifies(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_, _ = w.Write([]byte(`{"error":{"message":"slow down","code":"rate_limited"}}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
|
_, err := p.ListModels(context.Background())
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var apiErr *llm.APIError
|
|
if !errors.As(err, &apiErr) {
|
|
t.Fatalf("err type = %T, want *llm.APIError", err)
|
|
}
|
|
if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limited" {
|
|
t.Errorf("apiErr = %+v", apiErr)
|
|
}
|
|
if llm.Classify(err) != llm.ClassTransient {
|
|
t.Errorf("429 should classify transient")
|
|
}
|
|
}
|