fix(llamaswap): use A1111 /sdapi/v1/txt2img so seed is honored
The OpenAI /v1/images/generations endpoint ignores `seed` on our stable-diffusion.cpp build — every render of a given prompt comes back byte-identical, so a drawbot batch of N collapsed to one image. Switch the image provider to sd-server's A1111 /sdapi/v1/txt2img endpoint, which honors `seed` (verified live: distinct seeds -> distinct images on SDXL and Qwen-Image). Size is split into width/height; llama-swap still routes by the `model` field. Tests + ADR-0016 updated.
This commit is contained in:
+59
-43
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
|
||||
@@ -27,34 +28,32 @@ type imageModel struct {
|
||||
id string
|
||||
}
|
||||
|
||||
// imageRequest is the OpenAI /v1/images/generations request shape, plus the
|
||||
// stable-diffusion.cpp extras llama-swap forwards to sd-server. We always
|
||||
// request b64_json so the bytes come back inline (no second fetch). The
|
||||
// optional fields are pointers/omitempty so an unset value is omitted entirely
|
||||
// and sd-server falls back to the model's own default (a field name a given
|
||||
// sd-server build doesn't recognize is simply ignored — harmless).
|
||||
type imageRequest struct {
|
||||
// txt2imgRequest is the stable-diffusion.cpp sd-server A1111 request shape
|
||||
// (POST /sdapi/v1/txt2img). We use this endpoint rather than the OpenAI
|
||||
// /v1/images/generations one because that endpoint IGNORES `seed` on this
|
||||
// sd-server build — every render of a given prompt comes back byte-identical,
|
||||
// so a batch of N collapses to one image. /sdapi/v1/txt2img honours `seed`,
|
||||
// giving real variety. llama-swap still routes by the `model` field in the
|
||||
// body. Optional fields are pointers/omitempty so an unset value falls back to
|
||||
// the model's baked default (the per-model --steps/--cfg-scale/etc. flags).
|
||||
type txt2imgRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
Steps *int `json:"steps,omitempty"`
|
||||
CFGScale *float64 `json:"cfg_scale,omitempty"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Width *int `json:"width,omitempty"`
|
||||
Height *int `json:"height,omitempty"`
|
||||
SampleMethod string `json:"sample_method,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
BatchCount int `json:"batch_count,omitempty"`
|
||||
}
|
||||
|
||||
type imageResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
B64JSON string `json:"b64_json"`
|
||||
URL string `json:"url"`
|
||||
} `json:"data"`
|
||||
type txt2imgResponse struct {
|
||||
Images []string `json:"images"`
|
||||
}
|
||||
|
||||
// Generate implements imagegen.Model via POST {base}/v1/images/generations.
|
||||
// Generate implements imagegen.Model via POST {base}/sdapi/v1/txt2img.
|
||||
func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ...imagegen.Option) (*imagegen.Result, error) {
|
||||
req = req.Apply(opts...)
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
@@ -64,37 +63,35 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ..
|
||||
return nil, fmt.Errorf("%w: image count N must be >= 0, got %d", llm.ErrUnsupported, req.N)
|
||||
}
|
||||
|
||||
wire := imageRequest{
|
||||
Model: m.id,
|
||||
Prompt: req.Prompt,
|
||||
N: req.N,
|
||||
Size: req.Size,
|
||||
ResponseFormat: "b64_json",
|
||||
Steps: req.Steps,
|
||||
CFGScale: req.CFGScale,
|
||||
NegativePrompt: req.NegativePrompt,
|
||||
SampleMethod: req.Sampler,
|
||||
Seed: req.Seed,
|
||||
width, height, err := parseSize(req.Size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", llm.ErrUnsupported, err)
|
||||
}
|
||||
|
||||
var resp imageResponse
|
||||
if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", m.id, &wire, &resp); err != nil {
|
||||
wire := txt2imgRequest{
|
||||
Model: m.id,
|
||||
Prompt: req.Prompt,
|
||||
NegativePrompt: req.NegativePrompt,
|
||||
Seed: req.Seed,
|
||||
Steps: req.Steps,
|
||||
CFGScale: req.CFGScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
SampleMethod: req.Sampler,
|
||||
BatchCount: req.N,
|
||||
}
|
||||
|
||||
var resp txt2imgResponse
|
||||
if err := m.p.doJSON(ctx, http.MethodPost, "/sdapi/v1/txt2img", m.id, &wire, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &imagegen.Result{Raw: &resp}
|
||||
for i, d := range resp.Data {
|
||||
if d.B64JSON == "" {
|
||||
// Why error rather than skip: a url-only entry means the backend
|
||||
// ignored response_format; we don't fetch remote content (mirrors
|
||||
// llm.ImagePart's bytes-only contract), so surface it.
|
||||
return nil, &llm.APIError{
|
||||
Provider: m.p.name,
|
||||
Model: m.id,
|
||||
Message: fmt.Sprintf("image %d returned no inline b64_json data", i),
|
||||
}
|
||||
for i, b64 := range resp.Images {
|
||||
if b64 == "" {
|
||||
continue
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(d.B64JSON)
|
||||
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llama-swap: decode image %d: %w", i, err)
|
||||
}
|
||||
@@ -110,6 +107,25 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ..
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseSize splits a "WxH" string into width/height pointers. "" yields
|
||||
// (nil, nil) so the model's own default resolution applies.
|
||||
func parseSize(size string) (*int, *int, error) {
|
||||
size = strings.TrimSpace(size)
|
||||
if size == "" {
|
||||
return nil, nil, nil
|
||||
}
|
||||
parts := strings.SplitN(strings.ToLower(size), "x", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, nil, fmt.Errorf("invalid size %q (want WxH)", size)
|
||||
}
|
||||
w, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
h, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
|
||||
return nil, nil, fmt.Errorf("invalid size %q (want WxH)", size)
|
||||
}
|
||||
return &w, &h, nil
|
||||
}
|
||||
|
||||
// sniffImageMIME identifies the image format from its leading bytes, defaulting
|
||||
// to image/png (stable-diffusion.cpp emits PNG) when detection is inconclusive.
|
||||
func sniffImageMIME(data []byte) string {
|
||||
|
||||
Reference in New Issue
Block a user