diff --git a/config-schema.json b/config-schema.json index ce325105..4142c397 100644 --- a/config-schema.json +++ b/config-schema.json @@ -572,6 +572,24 @@ "default": {}, "description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap." }, + "upstream": { + "type": "object", + "description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.", + "properties": { + "ignorePaths": { + "type": "array", + "items": { + "type": "string" + }, + "default": [ + ".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$" + ], + "description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)." + } + }, + "additionalProperties": false, + "default": {} + }, "routing": { "type": "object", "description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.", diff --git a/config.example.yaml b/config.example.yaml index 72ecd037..27de1131 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -134,6 +134,18 @@ apiKeys: - "${env.API_KEY_1}" - "${env.API_KEY_2}" +# upstream: controls behaviour of the /upstream passthrough endpoint +# - optional, default: empty dictionary +# - recommended to only use in special use cases. Leaving it as the +# default will typically be the best experience +upstream: + # ignorePaths: list of RE2 compatible regular expressions + # - default: (see below) + # - any request to a path matching any of the regular expressions + # will be ignored and not trigger a swap + ignorePaths: + - '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$' + # models: a dictionary of model configurations # - required # - each key is the model's ID, used in API requests diff --git a/internal/config/config.go b/internal/config/config.go index 6e3aa464..3d34b060 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -163,6 +163,9 @@ type Config struct { // support remote peers, see issue #433, #296 Peers PeerDictionaryConfig `yaml:"peers"` + + // upstream controls behaviour of the /upstream passthrough endpoint + Upstream UpstreamConfig `yaml:"upstream"` } // RoutingConfig is the canonical, normalized routing/scheduling configuration. @@ -270,6 +273,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { return Config{}, fmt.Errorf("globalTTL must be >= 0") } + // Apply default for upstream.ignorePaths when not specified. The default + // matches common static-asset suffixes so they do not trigger a swap. + if len(config.Upstream.IgnorePaths) == 0 { + config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths() + } + switch config.LogToStdout { case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone: default: diff --git a/internal/config/config_posix_test.go b/internal/config/config_posix_test.go index 64124b5c..d3fd9546 100644 --- a/internal/config/config_posix_test.go +++ b/internal/config/config_posix_test.go @@ -266,6 +266,9 @@ groups: "mthree": "model3", }, Groups: expectedGroups, + Upstream: UpstreamConfig{ + IgnorePaths: DefaultUpstreamIgnorePaths(), + }, Routing: RoutingConfig{ Router: RouterConfig{ Use: "group", diff --git a/internal/config/config_windows_test.go b/internal/config/config_windows_test.go index 4992be3b..7f53a25d 100644 --- a/internal/config/config_windows_test.go +++ b/internal/config/config_windows_test.go @@ -255,6 +255,9 @@ groups: "mthree": "model3", }, Groups: expectedGroups, + Upstream: UpstreamConfig{ + IgnorePaths: DefaultUpstreamIgnorePaths(), + }, Routing: RoutingConfig{ Router: RouterConfig{ Use: "group", diff --git a/internal/config/upstream.go b/internal/config/upstream.go new file mode 100644 index 00000000..8419e817 --- /dev/null +++ b/internal/config/upstream.go @@ -0,0 +1,55 @@ +package config + +import ( + "fmt" + "regexp" + + "gopkg.in/yaml.v3" +) + +// DefaultUpstreamIgnorePathsPattern is the default regular expression applied +// to upstream.ignorePaths when the section is empty or absent from the config. +// It matches common static-asset suffixes so requests for .js/.css/.png/etc. +// files do not trigger a model swap. +const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$` + +// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used +// when upstream.ignorePaths is not specified in the config. The returned slice +// is fresh so callers may mutate it without affecting other configs. +func DefaultUpstreamIgnorePaths() []*regexp.Regexp { + return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)} +} + +// UpstreamConfig controls behaviour of the /upstream passthrough endpoint. +type UpstreamConfig struct { + // IgnorePaths is a slice of compiled regular expressions. Any request to + // /upstream// whose remaining path matches any of these + // expressions will be ignored and not trigger a swap. When the config + // does not specify any patterns, DefaultUpstreamIgnorePaths is applied. + IgnorePaths []*regexp.Regexp `yaml:"-"` +} + +// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into +// plain strings, which are then compiled into *regexp.Regexp. +type rawUpstreamConfig struct { + IgnorePaths []string `yaml:"ignorePaths"` +} + +// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any +// entry fails to compile, an error is returned. +func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error { + var raw rawUpstreamConfig + if err := value.Decode(&raw); err != nil { + return err + } + patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths)) + for _, p := range raw.IgnorePaths { + re, err := regexp.Compile(p) + if err != nil { + return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err) + } + patterns = append(patterns, re) + } + u.IgnorePaths = patterns + return nil +} diff --git a/internal/config/upstream_test.go b/internal/config/upstream_test.go new file mode 100644 index 00000000..19b61eb5 --- /dev/null +++ b/internal/config/upstream_test.go @@ -0,0 +1,88 @@ +package config + +import ( + "regexp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const upstreamConfigHeader = ` +models: + model1: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8080" +` + +func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) { + // When upstream is not specified at all, the default pattern is applied. + content := upstreamConfigHeader + cfg, err := LoadConfigFromReader(strings.NewReader(content)) + require.NoError(t, err) + require.Len(t, cfg.Upstream.IgnorePaths, 1) + + def := cfg.Upstream.IgnorePaths[0] + assert.IsType(t, ®exp.Regexp{}, def) + assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String()) + + // The default matches common static-asset suffixes. + assert.True(t, def.MatchString("/foo.js")) + assert.True(t, def.MatchString("/bar/baz.json")) + assert.True(t, def.MatchString("/static/img.png")) + assert.True(t, def.MatchString("/notes.txt")) + assert.True(t, def.MatchString("/favicon.ico")) + // And does not match inference API paths. + assert.False(t, def.MatchString("/v1/chat/completions")) + assert.False(t, def.MatchString("/v1/models")) + assert.False(t, def.MatchString("/health")) +} + +func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) { + // When upstream is present but ignorePaths is omitted, the default is still + // applied. + content := `upstream: {}` + "\n" + upstreamConfigHeader + cfg, err := LoadConfigFromReader(strings.NewReader(content)) + require.NoError(t, err) + require.Len(t, cfg.Upstream.IgnorePaths, 1) + assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String()) +} + +func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) { + content := ` +upstream: + ignorePaths: + - ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$" + - "^/static/.*" +` + upstreamConfigHeader + + cfg, err := LoadConfigFromReader(strings.NewReader(content)) + require.NoError(t, err) + require.Len(t, cfg.Upstream.IgnorePaths, 2) + + // Verify the patterns are compiled into *regexp.Regexp and match as expected. + assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js")) + assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json")) + assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions")) + assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png")) + assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions")) + + // Confirm the type is *regexp.Regexp to satisfy the API contract. + for _, re := range cfg.Upstream.IgnorePaths { + assert.IsType(t, ®exp.Regexp{}, re) + } +} + +func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) { + content := ` +upstream: + ignorePaths: + - "[invalid(" +` + upstreamConfigHeader + + _, err := LoadConfigFromReader(strings.NewReader(content)) + require.Error(t, err) + assert.Contains(t, err.Error(), "upstream.ignorePaths") + assert.Contains(t, err.Error(), "invalid regular expression") +} diff --git a/internal/server/api.go b/internal/server/api.go index 96cafed8..0da4a781 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "fmt" "net/http" "sort" "strings" @@ -9,6 +10,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/shared" ) @@ -340,6 +342,28 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { // Pin the resolved model so the router skips body/query extraction. *r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)})) + // If the path matches an upstream.ignorePaths entry and the model is + // not already loaded, refuse the request without triggering a swap. The + // server was not able to process the response because the model was not + // already loaded. + for _, re := range s.cfg.Upstream.IgnorePaths { + if !re.MatchString(remainingPath) { + continue + } + if s.local.Handles(modelID) { + state, ok := s.local.RunningModels()[modelID] + if !ok || state != process.StateReady { + shared.SendResponse(w, r, http.StatusConflict, + fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID)) + return + } + } + // Either the model is already loaded (no swap would be triggered) + // or this is a peer model (peer proxying never swaps). Fall through + // to normal dispatch. + break + } + switch { case s.local.Handles(modelID): s.local.ServeHTTP(w, r) diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 5924bc06..715a480f 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -5,11 +5,13 @@ import ( "io" "net/http" "net/http/httptest" + "regexp" "strings" "testing" "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/shared" ) @@ -156,6 +158,91 @@ func upstreamMetricsServer(response string) *Server { return s } +func TestServer_HandleUpstream_IgnorePaths(t *testing.T) { + // Compile a pattern that matches static asset suffixes. + pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`) + + t.Run("matched path, model not loaded, returns 409", func(t *testing.T) { + local := newStubRouter([]string{"m1"}, "upstream-body") + // running is nil/empty: model is not in RunningModels() => not loaded. + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{ + Models: map[string]config.ModelConfig{"m1": {}}, + Upstream: config.UpstreamConfig{ + IgnorePaths: []*regexp.Regexp{pattern}, + }, + } + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil)) + + if w.Code != http.StatusConflict { + t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "not loaded") { + t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String()) + } + }) + + t.Run("matched path, model already loaded, serves normally", func(t *testing.T) { + local := newStubRouter([]string{"m1"}, "upstream-body") + local.running = map[string]process.ProcessState{"m1": process.StateReady} + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{ + Models: map[string]config.ModelConfig{"m1": {}}, + Upstream: config.UpstreamConfig{ + IgnorePaths: []*regexp.Regexp{pattern}, + }, + } + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil)) + + if w.Code != http.StatusOK || w.Body.String() != "upstream-body" { + t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String()) + } + }) + + t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) { + local := newStubRouter([]string{"m1"}, "upstream-body") + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{ + Models: map[string]config.ModelConfig{"m1": {}}, + Upstream: config.UpstreamConfig{ + IgnorePaths: []*regexp.Regexp{pattern}, + }, + } + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil)) + + if w.Code != http.StatusOK || w.Body.String() != "upstream-body" { + t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String()) + } + }) + + t.Run("matched path, peer model, serves normally", func(t *testing.T) { + // Peer routers do not appear via RunningModels on the local router; + // they should fall through to normal dispatch without 409. + local := newStubRouter(nil, "") + peer := newStubRouter([]string{"m1"}, "peer-body") + s := newTestServer(local, peer) + s.cfg = config.Config{ + Models: map[string]config.ModelConfig{"m1": {}}, + Upstream: config.UpstreamConfig{ + IgnorePaths: []*regexp.Regexp{pattern}, + }, + } + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil)) + + if w.Code != http.StatusOK || w.Body.String() != "peer-body" { + t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String()) + } + }) +} + func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) { resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}` s := upstreamMetricsServer(resp)