From 64642c43c4367b9ecf8db7076aa8c7f7bb0f5595 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Sat, 27 Jun 2026 16:04:23 -0400 Subject: [PATCH] fix(llamaswap): address Gadfly review findings - Unload: reject model ids containing URL path, query, or fragment delimiters ("/", "?", "#") so a model name can't redirect the request to another endpoint; ":" (common in ids) stays verbatim. - doJSON: take a model arg so image/management HTTP errors carry the target id (was always ""); add a base-URL guard so management methods fail clearly instead of building a bare-path request; cap the success-path JSON decode with io.LimitReader (64 MiB) and drain the body (bounded by the same cap) when out is nil for connection reuse. - image: reject negative Request.N before sending. Co-Authored-By: Claude Opus 4.8 (1M context) --- provider/llamaswap/image.go | 5 +++- provider/llamaswap/llamaswap.go | 38 +++++++++++++++++++++------- provider/llamaswap/llamaswap_test.go | 18 +++++++++++++ 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/provider/llamaswap/image.go b/provider/llamaswap/image.go index 10c0796..e331fbd 100644 --- a/provider/llamaswap/image.go +++ b/provider/llamaswap/image.go @@ -51,6 +51,9 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts .. if strings.TrimSpace(req.Prompt) == "" { return nil, fmt.Errorf("%w: image generation requires a prompt", llm.ErrUnsupported) } + if req.N < 0 { + return nil, fmt.Errorf("%w: image count N must be >= 0, got %d", llm.ErrUnsupported, req.N) + } wire := imageRequest{ Model: m.id, @@ -61,7 +64,7 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts .. 
} var resp imageResponse - if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", &wire, &resp); err != nil { + if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", m.id, &wire, &resp); err != nil { return nil, err } diff --git a/provider/llamaswap/llamaswap.go b/provider/llamaswap/llamaswap.go index d7caa84..f973cf1 100644 --- a/provider/llamaswap/llamaswap.go +++ b/provider/llamaswap/llamaswap.go @@ -38,6 +38,11 @@ import ( // DefaultName is the registry name used when WithName is not given. const DefaultName = "llama-swap" +// maxResponseBytes caps the JSON body read on the success path. Generous +// enough for a multi-image b64 payload, bounded so a hostile/buggy upstream +// can't make a decode allocate without limit. +const maxResponseBytes = 64 << 20 + // Provider is a llama-swap client. It satisfies llm.Provider (chat, delegated // to provider/openai) and imagegen.Provider (image generation), and exposes // llama-swap's management endpoints as concrete methods. @@ -136,7 +141,7 @@ func (p *Provider) ListModels(ctx context.Context) ([]ModelInfo, error) { var out struct { Data []ModelInfo `json:"data"` } - if err := p.doJSON(ctx, http.MethodGet, "/v1/models", nil, &out); err != nil { + if err := p.doJSON(ctx, http.MethodGet, "/v1/models", "", nil, &out); err != nil { return nil, err } return out.Data, nil @@ -148,7 +153,7 @@ func (p *Provider) ListModels(ctx context.Context) ([]ModelInfo, error) { // would have to guess. 
func (p *Provider) Running(ctx context.Context) (json.RawMessage, error) { var out json.RawMessage - if err := p.doJSON(ctx, http.MethodGet, "/running", nil, &out); err != nil { + if err := p.doJSON(ctx, http.MethodGet, "/running", "", nil, &out); err != nil { return nil, err } return out, nil @@ -160,18 +165,30 @@ func (p *Provider) Running(ctx context.Context) (json.RawMessage, error) { func (p *Provider) Unload(ctx context.Context, model string) error { path := "/api/models/unload" if model != "" { + // Why reject rather than percent-escape: llama-swap model ids legitimately + // contain ":" (e.g. "qwen3:14b"), which is path-legal and must reach the + // server verbatim; only path-structure characters are dangerous (they'd + // redirect the request to another endpoint), and those never appear in a + // real model id. + if strings.ContainsAny(model, "/?#") { + return fmt.Errorf("llama-swap: invalid model id %q for unload (contains a path separator)", model) + } path += "/" + model } - return p.doJSON(ctx, http.MethodPost, path, nil, nil) + return p.doJSON(ctx, http.MethodPost, path, "", nil, nil) } // --- shared HTTP helper for management + image endpoints --- // doJSON performs a request to a llama-swap endpoint relative to baseURL, -// optionally encoding body and decoding into out (either may be nil). Transport -// failures are wrapped raw so llm.Classify still sees the underlying net error; -// non-2xx responses become *llm.APIError. -func (p *Provider) doJSON(ctx context.Context, method, path string, body, out any) error { +// optionally encoding body and decoding into out (either may be nil). model +// labels the failing target in any *llm.APIError ("" for endpoints that aren't +// model-specific). Transport failures are wrapped raw so llm.Classify still +// sees the underlying net error; non-2xx responses become *llm.APIError. 
+func (p *Provider) doJSON(ctx context.Context, method, path, model string, body, out any) error { + if p.baseURL == "" { + return fmt.Errorf("llama-swap provider %q: no base URL configured (set one via WithBaseURL or an LLM_* env DSN)", p.name) + } var rdr io.Reader if body != nil { b, err := json.Marshal(body) @@ -196,12 +213,15 @@ func (p *Provider) doJSON(ctx context.Context, method, path string, body, out an } defer resp.Body.Close() if resp.StatusCode/100 != 2 { - return p.apiError(resp, "") + return p.apiError(resp, model) } if out != nil { - if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(out); err != nil { return fmt.Errorf("llama-swap: decode response: %w", err) } + } else { + // Drain (bounded) so the connection can be reused. + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxResponseBytes)) } return nil } diff --git a/provider/llamaswap/llamaswap_test.go b/provider/llamaswap/llamaswap_test.go index 8b82e09..a210654 100644 --- a/provider/llamaswap/llamaswap_test.go +++ b/provider/llamaswap/llamaswap_test.go @@ -127,6 +127,24 @@ func TestUnload(t *testing.T) { if gotPath != "/api/models/unload" { t.Errorf("unload-all path = %q", gotPath) } + + // A model id with a path separator is rejected before any request. + if err := p.Unload(context.Background(), "../admin"); err == nil { + t.Error("expected error for model id with path separator") + } +} + +func TestManagementNoBaseURL(t *testing.T) { + p := New() // no base URL + if _, err := p.ListModels(context.Background()); err == nil { + t.Error("ListModels: expected error for missing base URL") + } + if _, err := p.Running(context.Background()); err == nil { + t.Error("Running: expected error for missing base URL") + } + if err := p.Unload(context.Background(), "m"); err == nil { + t.Error("Unload: expected error for missing base URL") + } } func TestRunningRaw(t *testing.T) {