fix(llamaswap): use A1111 /sdapi/v1/txt2img so seed is honored #10
@@ -52,3 +52,12 @@ contract is unchanged for callers that don't set them. `provider/llamaswap`
|
|||||||
forwards them to sd-server as `steps`/`cfg_scale`/`negative_prompt`/`sample_method`/
|
forwards them to sd-server as `steps`/`cfg_scale`/`negative_prompt`/`sample_method`/
|
||||||
`seed` (omitempty). This realizes the "seeds/steps … additive fields" note above;
|
`seed` (omitempty). This realizes the "seeds/steps … additive fields" note above;
|
||||||
img2img/masks/streaming remain deferred.
|
img2img/masks/streaming remain deferred.
|
||||||
|
|
||||||
|
## Update — A1111 txt2img endpoint (seed support)
|
||||||
|
|
||||||
|
`provider/llamaswap` now POSTs to sd-server's **`/sdapi/v1/txt2img`** (A1111)
|
||||||
|
instead of the OpenAI `/v1/images/generations`. That OpenAI endpoint **ignores
|
||||||
|
`seed`** on the stable-diffusion.cpp build we run — every render of a prompt is
|
||||||
|
byte-identical, so a batch of N collapses to one image. `/sdapi/v1/txt2img`
|
||||||
|
honours `seed`, restoring real per-render variety. llama-swap still routes by
|
||||||
|
the `model` field in the body; `Size` is split into `width`/`height`.
|
||||||
|
|||||||
+59
-43
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
|
"gitea.stevedudenhoeffer.com/steve/majordomo/imagegen"
|
||||||
@@ -27,34 +28,32 @@ type imageModel struct {
|
|||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
// imageRequest is the OpenAI /v1/images/generations request shape, plus the
|
// txt2imgRequest is the stable-diffusion.cpp sd-server A1111 request shape
|
||||||
// stable-diffusion.cpp extras llama-swap forwards to sd-server. We always
|
// (POST /sdapi/v1/txt2img). We use this endpoint rather than the OpenAI
|
||||||
// request b64_json so the bytes come back inline (no second fetch). The
|
// /v1/images/generations one because that endpoint IGNORES `seed` on this
|
||||||
// optional fields are pointers/omitempty so an unset value is omitted entirely
|
// sd-server build — every render of a given prompt comes back byte-identical,
|
||||||
// and sd-server falls back to the model's own default (a field name a given
|
// so a batch of N collapses to one image. /sdapi/v1/txt2img honours `seed`,
|
||||||
// sd-server build doesn't recognize is simply ignored — harmless).
|
// giving real variety. llama-swap still routes by the `model` field in the
|
||||||
type imageRequest struct {
|
// body. Optional fields are pointers/omitempty so an unset value falls back to
|
||||||
|
// the model's baked default (the per-model --steps/--cfg-scale/etc. flags).
|
||||||
|
type txt2imgRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
N int `json:"n,omitempty"`
|
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Seed *int64 `json:"seed,omitempty"`
|
||||||
ResponseFormat string `json:"response_format"`
|
|
||||||
Steps *int `json:"steps,omitempty"`
|
Steps *int `json:"steps,omitempty"`
|
||||||
CFGScale *float64 `json:"cfg_scale,omitempty"`
|
CFGScale *float64 `json:"cfg_scale,omitempty"`
|
||||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
Width *int `json:"width,omitempty"`
|
||||||
|
Height *int `json:"height,omitempty"`
|
||||||
SampleMethod string `json:"sample_method,omitempty"`
|
SampleMethod string `json:"sample_method,omitempty"`
|
||||||
Seed *int64 `json:"seed,omitempty"`
|
BatchCount int `json:"batch_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type imageResponse struct {
|
type txt2imgResponse struct {
|
||||||
Created int64 `json:"created"`
|
Images []string `json:"images"`
|
||||||
Data []struct {
|
|
||||||
B64JSON string `json:"b64_json"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate implements imagegen.Model via POST {base}/v1/images/generations.
|
// Generate implements imagegen.Model via POST {base}/sdapi/v1/txt2img.
|
||||||
func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ...imagegen.Option) (*imagegen.Result, error) {
|
func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ...imagegen.Option) (*imagegen.Result, error) {
|
||||||
req = req.Apply(opts...)
|
req = req.Apply(opts...)
|
||||||
if strings.TrimSpace(req.Prompt) == "" {
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
@@ -64,37 +63,35 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ..
|
|||||||
return nil, fmt.Errorf("%w: image count N must be >= 0, got %d", llm.ErrUnsupported, req.N)
|
return nil, fmt.Errorf("%w: image count N must be >= 0, got %d", llm.ErrUnsupported, req.N)
|
||||||
}
|
}
|
||||||
|
|
||||||
wire := imageRequest{
|
width, height, err := parseSize(req.Size)
|
||||||
Model: m.id,
|
if err != nil {
|
||||||
Prompt: req.Prompt,
|
return nil, fmt.Errorf("%w: %v", llm.ErrUnsupported, err)
|
||||||
N: req.N,
|
|
||||||
Size: req.Size,
|
|
||||||
ResponseFormat: "b64_json",
|
|
||||||
Steps: req.Steps,
|
|
||||||
CFGScale: req.CFGScale,
|
|
||||||
NegativePrompt: req.NegativePrompt,
|
|
||||||
SampleMethod: req.Sampler,
|
|
||||||
Seed: req.Seed,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp imageResponse
|
wire := txt2imgRequest{
|
||||||
if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", m.id, &wire, &resp); err != nil {
|
Model: m.id,
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
NegativePrompt: req.NegativePrompt,
|
||||||
|
Seed: req.Seed,
|
||||||
|
Steps: req.Steps,
|
||||||
|
CFGScale: req.CFGScale,
|
||||||
|
Width: width,
|
||||||
|
Height: height,
|
||||||
|
SampleMethod: req.Sampler,
|
||||||
|
BatchCount: req.N,
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp txt2imgResponse
|
||||||
|
if err := m.p.doJSON(ctx, http.MethodPost, "/sdapi/v1/txt2img", m.id, &wire, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
out := &imagegen.Result{Raw: &resp}
|
out := &imagegen.Result{Raw: &resp}
|
||||||
for i, d := range resp.Data {
|
for i, b64 := range resp.Images {
|
||||||
if d.B64JSON == "" {
|
if b64 == "" {
|
||||||
// Why error rather than skip: a url-only entry means the backend
|
continue
|
||||||
// ignored response_format; we don't fetch remote content (mirrors
|
|
||||||
// llm.ImagePart's bytes-only contract), so surface it.
|
|
||||||
return nil, &llm.APIError{
|
|
||||||
Provider: m.p.name,
|
|
||||||
Model: m.id,
|
|
||||||
Message: fmt.Sprintf("image %d returned no inline b64_json data", i),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
raw, err := base64.StdEncoding.DecodeString(d.B64JSON)
|
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("llama-swap: decode image %d: %w", i, err)
|
return nil, fmt.Errorf("llama-swap: decode image %d: %w", i, err)
|
||||||
}
|
}
|
||||||
@@ -110,6 +107,25 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ..
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseSize splits a "WxH" string into width/height pointers. "" yields
|
||||||
|
// (nil, nil) so the model's own default resolution applies.
|
||||||
|
func parseSize(size string) (*int, *int, error) {
|
||||||
|
size = strings.TrimSpace(size)
|
||||||
|
if size == "" {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(strings.ToLower(size), "x", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, nil, fmt.Errorf("invalid size %q (want WxH)", size)
|
||||||
|
}
|
||||||
|
w, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||||
|
h, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||||
|
if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
|
||||||
|
return nil, nil, fmt.Errorf("invalid size %q (want WxH)", size)
|
||||||
|
}
|
||||||
|
return &w, &h, nil
|
||||||
|
}
|
||||||
|
|
||||||
// sniffImageMIME identifies the image format from its leading bytes, defaulting
|
// sniffImageMIME identifies the image format from its leading bytes, defaulting
|
||||||
// to image/png (stable-diffusion.cpp emits PNG) when detection is inconclusive.
|
// to image/png (stable-diffusion.cpp emits PNG) when detection is inconclusive.
|
||||||
func sniffImageMIME(data []byte) string {
|
func sniffImageMIME(data []byte) string {
|
||||||
|
|||||||
@@ -166,11 +166,11 @@ func TestRunningRaw(t *testing.T) {
|
|||||||
func TestImageGenerate(t *testing.T) {
|
func TestImageGenerate(t *testing.T) {
|
||||||
var gotBody map[string]any
|
var gotBody map[string]any
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path != "/v1/images/generations" {
|
if r.URL.Path != "/sdapi/v1/txt2img" {
|
||||||
t.Errorf("path = %q", r.URL.Path)
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
}
|
}
|
||||||
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
||||||
_, _ = w.Write([]byte(`{"created":1,"data":[{"b64_json":"` + onePixelPNG + `"}]}`))
|
_, _ = w.Write([]byte(`{"images":["` + onePixelPNG + `"]}`))
|
||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
@@ -192,12 +192,9 @@ func TestImageGenerate(t *testing.T) {
|
|||||||
if len(res.Images[0].Data) == 0 {
|
if len(res.Images[0].Data) == 0 {
|
||||||
t.Error("decoded image has no bytes")
|
t.Error("decoded image has no bytes")
|
||||||
}
|
}
|
||||||
// response_format must be forced to b64_json, and options applied.
|
// Size is split into width/height ints for the A1111 endpoint.
|
||||||
if gotBody["response_format"] != "b64_json" {
|
if gotBody["width"] != float64(512) || gotBody["height"] != float64(512) {
|
||||||
t.Errorf("response_format = %v, want b64_json", gotBody["response_format"])
|
t.Errorf("width/height = %v/%v, want 512/512", gotBody["width"], gotBody["height"])
|
||||||
}
|
|
||||||
if gotBody["size"] != "512x512" {
|
|
||||||
t.Errorf("size = %v, want 512x512", gotBody["size"])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,7 +202,7 @@ func TestImageGenerateSettings(t *testing.T) {
|
|||||||
var gotBody map[string]any
|
var gotBody map[string]any
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
||||||
_, _ = w.Write([]byte(`{"created":1,"data":[{"b64_json":"` + onePixelPNG + `"}]}`))
|
_, _ = w.Write([]byte(`{"images":["` + onePixelPNG + `"]}`))
|
||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user