feat(llamaswap): add llama-swap provider + canonical imagegen interface
Add provider/llamaswap, a tailored provider for llama-swap (the model-swapping
proxy over llama.cpp / stable-diffusion.cpp). Its chat path delegates to
provider/openai at {base}/v1 — no duplicated wire client (ADR-0007) — with
legacy max_tokens, a Bearer no-key placeholder for keyless local instances, and
a timeout-free client so cold model swaps rely on context deadlines. The
"tailored" surface is concrete management methods (ListModels / Running /
Unload) that don't belong on the canonical llm.Provider interface. The
llama-swap:// DSN scheme builds an http base URL (local-first); a no-URL
built-in errors clearly on use, mirroring foreman.
Add imagegen, a new canonical text-to-image interface separate from llm
(Request/Result/Model/Provider; Image = llm.ImagePart so generated images feed
straight back into chat). First backend is llama-swap via OpenAI
/v1/images/generations (b64_json, bytes-only). Re-exported from the root. v1 is
txt2img only.
Hermetic httptest coverage for chat delegation, management endpoints, image
decode, and scheme wiring. ADR-0015 + ADR-0016, README support matrix +
image-gen section, CLAUDE.md package map, and progress.md updated in the same
commit.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,104 @@
|
||||
package llamaswap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// ImageModel implements imagegen.Provider, binding an image-generation model
|
||||
// served by llama-swap (routed to a stable-diffusion.cpp upstream). The id is
|
||||
// passed through verbatim and selects which upstream llama-swap loads.
|
||||
func (p *Provider) ImageModel(id string, opts ...imagegen.ModelOption) (imagegen.Model, error) {
|
||||
if p.baseURL == "" {
|
||||
return nil, fmt.Errorf("llama-swap provider %q: no base URL configured (set one via WithBaseURL or an LLM_* env DSN)", p.name)
|
||||
}
|
||||
_ = imagegen.ApplyModelOptions(opts)
|
||||
return &imageModel{p: p, id: id}, nil
|
||||
}
|
||||
|
||||
type imageModel struct {
|
||||
p *Provider
|
||||
id string
|
||||
}
|
||||
|
||||
// imageRequest is the OpenAI /v1/images/generations request shape. We always
|
||||
// request b64_json so the bytes come back inline (no second fetch).
|
||||
type imageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
}
|
||||
|
||||
type imageResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
B64JSON string `json:"b64_json"`
|
||||
URL string `json:"url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// Generate implements imagegen.Model via POST {base}/v1/images/generations.
|
||||
func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ...imagegen.Option) (*imagegen.Result, error) {
|
||||
req = req.Apply(opts...)
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return nil, fmt.Errorf("%w: image generation requires a prompt", llm.ErrUnsupported)
|
||||
}
|
||||
|
||||
wire := imageRequest{
|
||||
Model: m.id,
|
||||
Prompt: req.Prompt,
|
||||
N: req.N,
|
||||
Size: req.Size,
|
||||
ResponseFormat: "b64_json",
|
||||
}
|
||||
|
||||
var resp imageResponse
|
||||
if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", &wire, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &imagegen.Result{Raw: &resp}
|
||||
for i, d := range resp.Data {
|
||||
if d.B64JSON == "" {
|
||||
// Why error rather than skip: a url-only entry means the backend
|
||||
// ignored response_format; we don't fetch remote content (mirrors
|
||||
// llm.ImagePart's bytes-only contract), so surface it.
|
||||
return nil, &llm.APIError{
|
||||
Provider: m.p.name,
|
||||
Model: m.id,
|
||||
Message: fmt.Sprintf("image %d returned no inline b64_json data", i),
|
||||
}
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(d.B64JSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llama-swap: decode image %d: %w", i, err)
|
||||
}
|
||||
out.Images = append(out.Images, llm.ImagePart{MIME: sniffImageMIME(raw), Data: raw})
|
||||
}
|
||||
if len(out.Images) == 0 {
|
||||
return nil, &llm.APIError{
|
||||
Provider: m.p.name,
|
||||
Model: m.id,
|
||||
Message: "image response contained no images",
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// sniffImageMIME identifies the image format from its leading bytes, defaulting
|
||||
// to image/png (stable-diffusion.cpp emits PNG) when detection is inconclusive.
|
||||
func sniffImageMIME(data []byte) string {
|
||||
mime := http.DetectContentType(data)
|
||||
if !strings.HasPrefix(mime, "image/") {
|
||||
return "image/png"
|
||||
}
|
||||
return mime
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
// Package llamaswap implements majordomo's provider contract for llama-swap
|
||||
// (https://github.com/mostlygeek/llama-swap), an on-demand model-swapping
|
||||
// proxy that fronts llama.cpp (and stable-diffusion.cpp) servers, loading and
|
||||
// hot-swapping the requested model per request.
|
||||
//
|
||||
// Chat is OpenAI Chat Completions, byte-for-byte: this package does NOT carry
|
||||
// its own chat wire client. Provider.Model delegates to provider/openai
|
||||
// pointed at {baseURL}/v1 (ADR-0007: reuse, don't duplicate). What this
|
||||
// package adds beyond a bare OpenAI-compat endpoint is the "tailored" surface:
|
||||
//
|
||||
// - llama-swap management endpoints exposed as concrete methods — ListModels
|
||||
// (GET /v1/models), Running (GET /running), Unload (POST /api/models/unload)
|
||||
// — which have no place on the canonical llm.Provider interface;
|
||||
// - image generation via the imagegen interface (see image.go); and
|
||||
// - swap-aware defaults: the HTTP client carries NO timeout, because the
|
||||
// first request to an unloaded model blocks while llama-swap spawns the
|
||||
// upstream (its healthCheckTimeout is at least 15s). Bound a call with a
|
||||
// context deadline, never a client timeout.
|
||||
//
|
||||
// DSN form (registered as the "llama-swap" scheme): llama-swap://token@host:port
|
||||
// builds an http:// base URL (llama-swap is local-first; a TLS-fronted instance
|
||||
// can use the openai:// scheme for chat instead).
|
||||
package llamaswap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/openai"
|
||||
)
|
||||
|
||||
// DefaultName is the registry name used when WithName is not given.
|
||||
const DefaultName = "llama-swap"
|
||||
|
||||
// Provider is a llama-swap client. It satisfies llm.Provider (chat, delegated
|
||||
// to provider/openai) and imagegen.Provider (image generation), and exposes
|
||||
// llama-swap's management endpoints as concrete methods.
|
||||
type Provider struct {
|
||||
name string
|
||||
baseURL string // no trailing slash, no /v1 suffix; e.g. "http://host:port"
|
||||
token string // bearer credential; empty = no auth (local)
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// Option configures the provider.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithName overrides the registry name (default "llama-swap").
|
||||
func WithName(name string) Option { return func(p *Provider) { p.name = name } }
|
||||
|
||||
// WithBaseURL sets the llama-swap base URL (scheme://host[:port]); the /v1 and
|
||||
// management paths are appended internally. A trailing slash is trimmed.
|
||||
func WithBaseURL(u string) Option {
|
||||
return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") }
|
||||
}
|
||||
|
||||
// WithToken sets the bearer token (llama-swap API key). Empty means no
|
||||
// Authorization header.
|
||||
func WithToken(token string) Option { return func(p *Provider) { p.token = token } }
|
||||
|
||||
// WithHTTPClient overrides the HTTP client. Prefer context deadlines over a
|
||||
// client timeout: a cold model swap can legitimately take many seconds.
|
||||
func WithHTTPClient(c *http.Client) Option {
|
||||
return func(p *Provider) {
|
||||
if c != nil {
|
||||
p.client = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a llama-swap provider. Construction never fails; a missing base
|
||||
// URL surfaces at request time. The default client has no timeout (swap cold
|
||||
// starts); bound calls with a context deadline.
|
||||
func New(opts ...Option) *Provider {
|
||||
p := &Provider{name: DefaultName, client: &http.Client{}}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements llm.Provider and imagegen.Provider.
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// BaseURL reports the configured base URL (diagnostics).
|
||||
func (p *Provider) BaseURL() string { return p.baseURL }
|
||||
|
||||
// Model implements llm.Provider via llama-swap's OpenAI-compatible chat
|
||||
// endpoint, delegating to provider/openai. The id is passed through verbatim
|
||||
// and selects which upstream llama-swap loads.
|
||||
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
|
||||
if p.baseURL == "" {
|
||||
return nil, fmt.Errorf("llama-swap provider %q: no base URL configured (set one via WithBaseURL or an LLM_* env DSN)", p.name)
|
||||
}
|
||||
return p.chatProvider().Model(id, opts...)
|
||||
}
|
||||
|
||||
// chatProvider builds the OpenAI-compatible client for llama-swap's chat API.
|
||||
// Why a placeholder key when token is empty: the openai client treats a blank
|
||||
// key as a synthetic 401, but a local llama-swap may require no auth at all —
|
||||
// a bearer it ignores is harmless. Why legacy max_tokens: llama.cpp's OpenAI
|
||||
// shim honors "max_tokens", not "max_completion_tokens".
|
||||
func (p *Provider) chatProvider() *openai.Provider {
|
||||
key := p.token
|
||||
if key == "" {
|
||||
key = "no-key"
|
||||
}
|
||||
return openai.New(
|
||||
openai.WithName(p.name),
|
||||
openai.WithBaseURL(p.baseURL+"/v1"),
|
||||
openai.WithAPIKey(key),
|
||||
openai.WithLegacyMaxTokens(),
|
||||
openai.WithHTTPClient(p.client),
|
||||
)
|
||||
}
|
||||
|
||||
// --- management endpoints ---
|
||||
|
||||
// ModelInfo is one entry from llama-swap's GET /v1/models (the OpenAI model
|
||||
// list shape). Fields llama-swap adds beyond these are ignored.
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
// ListModels returns the models llama-swap is configured to serve (GET
|
||||
// /v1/models). Unlisted models are excluded by llama-swap itself.
|
||||
func (p *Provider) ListModels(ctx context.Context) ([]ModelInfo, error) {
|
||||
var out struct {
|
||||
Data []ModelInfo `json:"data"`
|
||||
}
|
||||
if err := p.doJSON(ctx, http.MethodGet, "/v1/models", nil, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out.Data, nil
|
||||
}
|
||||
|
||||
// Running returns llama-swap's currently-loaded models as the raw GET /running
|
||||
// payload. Why raw: llama-swap's /running shape is not a stable, OpenAI-style
|
||||
// contract, so this exposes the endpoint without pinning a schema this package
|
||||
// would have to guess.
|
||||
func (p *Provider) Running(ctx context.Context) (json.RawMessage, error) {
|
||||
var out json.RawMessage
|
||||
if err := p.doJSON(ctx, http.MethodGet, "/running", nil, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Unload unloads a running model to free its resources (POST
|
||||
// /api/models/unload/:model). An empty model unloads all running models (POST
|
||||
// /api/models/unload).
|
||||
func (p *Provider) Unload(ctx context.Context, model string) error {
|
||||
path := "/api/models/unload"
|
||||
if model != "" {
|
||||
path += "/" + model
|
||||
}
|
||||
return p.doJSON(ctx, http.MethodPost, path, nil, nil)
|
||||
}
|
||||
|
||||
// --- shared HTTP helper for management + image endpoints ---
|
||||
|
||||
// doJSON performs a request to a llama-swap endpoint relative to baseURL,
|
||||
// optionally encoding body and decoding into out (either may be nil). Transport
|
||||
// failures are wrapped raw so llm.Classify still sees the underlying net error;
|
||||
// non-2xx responses become *llm.APIError.
|
||||
func (p *Provider) doJSON(ctx context.Context, method, path string, body, out any) error {
|
||||
var rdr io.Reader
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("llama-swap: encode request: %w", err)
|
||||
}
|
||||
rdr = bytes.NewReader(b)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, p.baseURL+path, rdr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("llama-swap: build request: %w", err)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if p.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.token)
|
||||
}
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("llama-swap: do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return p.apiError(resp, "")
|
||||
}
|
||||
if out != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
|
||||
return fmt.Errorf("llama-swap: decode response: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// apiError converts a non-2xx response into *llm.APIError, tolerating the
|
||||
// OpenAI {"error":{"message",...}} envelope, the Ollama-style {"error":"..."}
|
||||
// string form, and a raw body.
|
||||
func (p *Provider) apiError(resp *http.Response, model string) error {
|
||||
e := &llm.APIError{Provider: p.name, Model: model, Status: resp.StatusCode}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
|
||||
var env struct {
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(body, &env) == nil && len(env.Error) > 0 {
|
||||
var obj struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if json.Unmarshal(env.Error, &obj) == nil && (obj.Message != "" || obj.Code != "" || obj.Type != "") {
|
||||
e.Message = obj.Message
|
||||
e.Code = obj.Code
|
||||
if e.Code == "" {
|
||||
e.Code = obj.Type
|
||||
}
|
||||
return e
|
||||
}
|
||||
var msg string
|
||||
if json.Unmarshal(env.Error, &msg) == nil && msg != "" {
|
||||
e.Message = msg
|
||||
return e
|
||||
}
|
||||
}
|
||||
e.Message = strings.TrimSpace(string(body))
|
||||
return e
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
package llamaswap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// 1x1 transparent PNG, base64 (used to assert image decoding end-to-end).
|
||||
const onePixelPNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
|
||||
func TestChatDelegatesToOpenAI(t *testing.T) {
|
||||
var gotPath, gotAuth string
|
||||
var gotBody map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithToken("test-token"), WithHTTPClient(srv.Client()))
|
||||
m, err := p.Model("qwen3:14b")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
resp, err := m.Generate(context.Background(), llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hello")},
|
||||
MaxTokens: 64,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if resp.Text() != "hi" {
|
||||
t.Errorf("Text = %q, want %q", resp.Text(), "hi")
|
||||
}
|
||||
if gotPath != "/v1/chat/completions" {
|
||||
t.Errorf("path = %q, want /v1/chat/completions", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer test-token" {
|
||||
t.Errorf("auth = %q, want Bearer test-token", gotAuth)
|
||||
}
|
||||
// llama.cpp's OpenAI shim wants the legacy max_tokens field.
|
||||
if _, ok := gotBody["max_tokens"]; !ok {
|
||||
t.Errorf("request missing max_tokens (legacy); body=%v", gotBody)
|
||||
}
|
||||
if _, ok := gotBody["max_completion_tokens"]; ok {
|
||||
t.Errorf("request used max_completion_tokens; want legacy max_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatNoTokenSendsPlaceholder(t *testing.T) {
|
||||
var gotAuth string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client())) // no token
|
||||
m, _ := p.Model("m")
|
||||
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("x")}}); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
// Keyless local llama-swap: a placeholder bearer it ignores, never a blank
|
||||
// that the openai client would reject as a synthetic 401.
|
||||
if gotAuth != "Bearer no-key" {
|
||||
t.Errorf("auth = %q, want Bearer no-key", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelNoBaseURL(t *testing.T) {
|
||||
if _, err := New().Model("m"); err == nil {
|
||||
t.Fatal("expected error for missing base URL")
|
||||
}
|
||||
if _, err := New().ImageModel("m"); err == nil {
|
||||
t.Fatal("expected error for missing base URL (image)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models" {
|
||||
t.Errorf("path = %q", r.URL.Path)
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"qwen3:14b","object":"model","owned_by":"llama-swap"},{"id":"sd","object":"model"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
||||
models, err := p.ListModels(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("ListModels: %v", err)
|
||||
}
|
||||
if len(models) != 2 || models[0].ID != "qwen3:14b" {
|
||||
t.Fatalf("models = %+v", models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnload(t *testing.T) {
|
||||
var gotPath, gotMethod string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath, gotMethod = r.URL.Path, r.Method
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
||||
if err := p.Unload(context.Background(), "qwen3:14b"); err != nil {
|
||||
t.Fatalf("Unload: %v", err)
|
||||
}
|
||||
if gotMethod != http.MethodPost || gotPath != "/api/models/unload/qwen3:14b" {
|
||||
t.Errorf("got %s %s", gotMethod, gotPath)
|
||||
}
|
||||
|
||||
if err := p.Unload(context.Background(), ""); err != nil {
|
||||
t.Fatalf("Unload all: %v", err)
|
||||
}
|
||||
if gotPath != "/api/models/unload" {
|
||||
t.Errorf("unload-all path = %q", gotPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunningRaw(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"running":["qwen3:14b"]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
||||
raw, err := p.Running(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Running: %v", err)
|
||||
}
|
||||
if string(raw) != `{"running":["qwen3:14b"]}` {
|
||||
t.Errorf("raw = %s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerate(t *testing.T) {
|
||||
var gotBody map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/images/generations" {
|
||||
t.Errorf("path = %q", r.URL.Path)
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
||||
_, _ = w.Write([]byte(`{"created":1,"data":[{"b64_json":"` + onePixelPNG + `"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
||||
im, err := p.ImageModel("sd")
|
||||
if err != nil {
|
||||
t.Fatalf("ImageModel: %v", err)
|
||||
}
|
||||
res, err := im.Generate(context.Background(), imagegen.Request{Prompt: "a red bicycle"}, imagegen.WithSize("512x512"))
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if len(res.Images) != 1 {
|
||||
t.Fatalf("images = %d, want 1", len(res.Images))
|
||||
}
|
||||
if res.Images[0].MIME != "image/png" {
|
||||
t.Errorf("MIME = %q, want image/png", res.Images[0].MIME)
|
||||
}
|
||||
if len(res.Images[0].Data) == 0 {
|
||||
t.Error("decoded image has no bytes")
|
||||
}
|
||||
// response_format must be forced to b64_json, and options applied.
|
||||
if gotBody["response_format"] != "b64_json" {
|
||||
t.Errorf("response_format = %v, want b64_json", gotBody["response_format"])
|
||||
}
|
||||
if gotBody["size"] != "512x512" {
|
||||
t.Errorf("size = %v, want 512x512", gotBody["size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerateEmptyPrompt(t *testing.T) {
|
||||
p := New(WithBaseURL("http://example.invalid"))
|
||||
im, _ := p.ImageModel("sd")
|
||||
_, err := im.Generate(context.Background(), imagegen.Request{Prompt: " "})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIErrorClassifies(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"message":"slow down","code":"rate_limited"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(WithBaseURL(srv.URL), WithHTTPClient(srv.Client()))
|
||||
_, err := p.ListModels(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
var apiErr *llm.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("err type = %T, want *llm.APIError", err)
|
||||
}
|
||||
if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limited" {
|
||||
t.Errorf("apiErr = %+v", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Errorf("429 should classify transient")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user