fix(llamaswap): address Gadfly review findings
- Unload: reject model ids containing URL-structure characters ("/", "?", "#") so a model name can't redirect the request to another endpoint; ":" (common in ids) stays verbatim.
- doJSON: take a model arg so image/management HTTP errors carry the target id (it was always ""); add a base-URL guard so management methods fail clearly instead of building a bare-path request; cap the success-path JSON decode with io.LimitReader (64 MiB); and drain the body when out is nil so the connection can be reused.
- image: reject a negative Request.N before sending.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -51,6 +51,9 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ..
|
|||||||
if strings.TrimSpace(req.Prompt) == "" {
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
return nil, fmt.Errorf("%w: image generation requires a prompt", llm.ErrUnsupported)
|
return nil, fmt.Errorf("%w: image generation requires a prompt", llm.ErrUnsupported)
|
||||||
}
|
}
|
||||||
|
if req.N < 0 {
|
||||||
|
return nil, fmt.Errorf("%w: image count N must be >= 0, got %d", llm.ErrUnsupported, req.N)
|
||||||
|
}
|
||||||
|
|
||||||
wire := imageRequest{
|
wire := imageRequest{
|
||||||
Model: m.id,
|
Model: m.id,
|
||||||
@@ -61,7 +64,7 @@ func (m *imageModel) Generate(ctx context.Context, req imagegen.Request, opts ..
|
|||||||
}
|
}
|
||||||
|
|
||||||
var resp imageResponse
|
var resp imageResponse
|
||||||
if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", &wire, &resp); err != nil {
|
if err := m.p.doJSON(ctx, http.MethodPost, "/v1/images/generations", m.id, &wire, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,11 @@ import (
|
|||||||
// DefaultName is the registry name used when WithName is not given.
|
// DefaultName is the registry name used when WithName is not given.
|
||||||
const DefaultName = "llama-swap"
|
const DefaultName = "llama-swap"
|
||||||
|
|
||||||
|
// maxResponseBytes caps the JSON body read on the success path. Generous
|
||||||
|
// enough for a multi-image b64 payload, bounded so a hostile/buggy upstream
|
||||||
|
// can't make a decode allocate without limit.
|
||||||
|
const maxResponseBytes = 64 << 20
|
||||||
|
|
||||||
// Provider is a llama-swap client. It satisfies llm.Provider (chat, delegated
|
// Provider is a llama-swap client. It satisfies llm.Provider (chat, delegated
|
||||||
// to provider/openai) and imagegen.Provider (image generation), and exposes
|
// to provider/openai) and imagegen.Provider (image generation), and exposes
|
||||||
// llama-swap's management endpoints as concrete methods.
|
// llama-swap's management endpoints as concrete methods.
|
||||||
@@ -136,7 +141,7 @@ func (p *Provider) ListModels(ctx context.Context) ([]ModelInfo, error) {
|
|||||||
var out struct {
|
var out struct {
|
||||||
Data []ModelInfo `json:"data"`
|
Data []ModelInfo `json:"data"`
|
||||||
}
|
}
|
||||||
if err := p.doJSON(ctx, http.MethodGet, "/v1/models", nil, &out); err != nil {
|
if err := p.doJSON(ctx, http.MethodGet, "/v1/models", "", nil, &out); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out.Data, nil
|
return out.Data, nil
|
||||||
@@ -148,7 +153,7 @@ func (p *Provider) ListModels(ctx context.Context) ([]ModelInfo, error) {
|
|||||||
// would have to guess.
|
// would have to guess.
|
||||||
func (p *Provider) Running(ctx context.Context) (json.RawMessage, error) {
|
func (p *Provider) Running(ctx context.Context) (json.RawMessage, error) {
|
||||||
var out json.RawMessage
|
var out json.RawMessage
|
||||||
if err := p.doJSON(ctx, http.MethodGet, "/running", nil, &out); err != nil {
|
if err := p.doJSON(ctx, http.MethodGet, "/running", "", nil, &out); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
@@ -160,18 +165,30 @@ func (p *Provider) Running(ctx context.Context) (json.RawMessage, error) {
|
|||||||
func (p *Provider) Unload(ctx context.Context, model string) error {
|
func (p *Provider) Unload(ctx context.Context, model string) error {
|
||||||
path := "/api/models/unload"
|
path := "/api/models/unload"
|
||||||
if model != "" {
|
if model != "" {
|
||||||
|
// Why reject rather than percent-escape: llama-swap model ids legitimately
|
||||||
|
// contain ":" (e.g. "qwen3:14b"), which is path-legal and must reach the
|
||||||
|
// server verbatim; only path-structure characters are dangerous (they'd
|
||||||
|
// redirect the request to another endpoint), and those never appear in a
|
||||||
|
// real model id.
|
||||||
|
if strings.ContainsAny(model, "/?#") {
|
||||||
|
return fmt.Errorf("llama-swap: invalid model id %q for unload (contains a path separator)", model)
|
||||||
|
}
|
||||||
path += "/" + model
|
path += "/" + model
|
||||||
}
|
}
|
||||||
return p.doJSON(ctx, http.MethodPost, path, nil, nil)
|
return p.doJSON(ctx, http.MethodPost, path, "", nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- shared HTTP helper for management + image endpoints ---
|
// --- shared HTTP helper for management + image endpoints ---
|
||||||
|
|
||||||
// doJSON performs a request to a llama-swap endpoint relative to baseURL,
|
// doJSON performs a request to a llama-swap endpoint relative to baseURL,
|
||||||
// optionally encoding body and decoding into out (either may be nil). Transport
|
// optionally encoding body and decoding into out (either may be nil). model
|
||||||
// failures are wrapped raw so llm.Classify still sees the underlying net error;
|
// labels the failing target in any *llm.APIError ("" for endpoints that aren't
|
||||||
// non-2xx responses become *llm.APIError.
|
// model-specific). Transport failures are wrapped raw so llm.Classify still
|
||||||
func (p *Provider) doJSON(ctx context.Context, method, path string, body, out any) error {
|
// sees the underlying net error; non-2xx responses become *llm.APIError.
|
||||||
|
func (p *Provider) doJSON(ctx context.Context, method, path, model string, body, out any) error {
|
||||||
|
if p.baseURL == "" {
|
||||||
|
return fmt.Errorf("llama-swap provider %q: no base URL configured (set one via WithBaseURL or an LLM_* env DSN)", p.name)
|
||||||
|
}
|
||||||
var rdr io.Reader
|
var rdr io.Reader
|
||||||
if body != nil {
|
if body != nil {
|
||||||
b, err := json.Marshal(body)
|
b, err := json.Marshal(body)
|
||||||
@@ -196,12 +213,15 @@ func (p *Provider) doJSON(ctx context.Context, method, path string, body, out an
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode/100 != 2 {
|
if resp.StatusCode/100 != 2 {
|
||||||
return p.apiError(resp, "")
|
return p.apiError(resp, model)
|
||||||
}
|
}
|
||||||
if out != nil {
|
if out != nil {
|
||||||
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
|
if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(out); err != nil {
|
||||||
return fmt.Errorf("llama-swap: decode response: %w", err)
|
return fmt.Errorf("llama-swap: decode response: %w", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Drain (bounded) so the connection can be reused.
|
||||||
|
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxResponseBytes))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -127,6 +127,24 @@ func TestUnload(t *testing.T) {
|
|||||||
if gotPath != "/api/models/unload" {
|
if gotPath != "/api/models/unload" {
|
||||||
t.Errorf("unload-all path = %q", gotPath)
|
t.Errorf("unload-all path = %q", gotPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A model id with a path separator is rejected before any request.
|
||||||
|
if err := p.Unload(context.Background(), "../admin"); err == nil {
|
||||||
|
t.Error("expected error for model id with path separator")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagementNoBaseURL(t *testing.T) {
|
||||||
|
p := New() // no base URL
|
||||||
|
if _, err := p.ListModels(context.Background()); err == nil {
|
||||||
|
t.Error("ListModels: expected error for missing base URL")
|
||||||
|
}
|
||||||
|
if _, err := p.Running(context.Background()); err == nil {
|
||||||
|
t.Error("Running: expected error for missing base URL")
|
||||||
|
}
|
||||||
|
if err := p.Unload(context.Background(), "m"); err == nil {
|
||||||
|
t.Error("Unload: expected error for missing base URL")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunningRaw(t *testing.T) {
|
func TestRunningRaw(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user