diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml
index 0a95ab21..6c3976c7 100644
--- a/.github/workflows/go-ci.yml
+++ b/.github/workflows/go-ci.yml
@@ -2,69 +2,68 @@ name: Linux CI
on:
push:
- branches: [ "main" ]
+ branches: ["main"]
# only run when backend source changes
# cmd/ is excluded because it contains utilities without tests
paths:
- - '**/*.go'
- - '!cmd/**'
- - 'go.mod'
- - 'go.sum'
- - 'Makefile'
- - '.github/workflows/go-ci.yml'
+ - "**/*.go"
+ - "!cmd/**"
+ - "go.mod"
+ - "go.sum"
+ - "Makefile"
+ - ".github/workflows/go-ci.yml"
pull_request:
- branches: [ "main" ]
+ branches: ["main"]
paths:
- - '**/*.go'
- - '!cmd/**'
- - 'go.mod'
- - 'go.sum'
- - 'Makefile'
- - '.github/workflows/go-ci.yml'
+ - "**/*.go"
+ - "!cmd/**"
+ - "go.mod"
+ - "go.sum"
+ - "Makefile"
+ - ".github/workflows/go-ci.yml"
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
-
run-tests:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v4
- - name: Set up Go
- uses: actions/setup-go@v4
- with:
- go-version-file: go.mod
+ - name: Set up Go
+ uses: actions/setup-go@v4
+ with:
+ go-version-file: go.mod
- # Only run in this linux based runner
- - name: Check Formatting
- run: |
- if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
- gofmt -l . | grep -v 'event/.*_test.go'
- exit 1
- fi
- # cache simple-responder to save the build time
- - name: Restore Simple Responder
- id: restore-simple-responder
- uses: actions/cache/restore@v4
- with:
- path: ./build
- key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
+ # Only run in this linux based runner
+ - name: Check Formatting
+ run: |
+ if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
+ gofmt -l . | grep -v 'event/.*_test.go'
+ exit 1
+ fi
+ # cache simple-responder to save the build time
+ - name: Restore Simple Responder
+ id: restore-simple-responder
+ uses: actions/cache/restore@v4
+ with:
+ path: ./build
+ key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- # necessary for testing proxy/Process swapping
- - name: Create simple-responder
- run: make simple-responder
+ # necessary for testing proxy/Process swapping
+ - name: Create simple-responder
+ run: make simple-responder
- - name: Save Simple Responder
- # nothing new to save ... skip this step
- if: steps.restore-simple-responder.outputs.cache-hit != 'true'
- id: save-simple-responder
- uses: actions/cache/save@v4
- with:
- path: ./build
- key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
+ - name: Save Simple Responder
+ # nothing new to save ... skip this step
+ if: steps.restore-simple-responder.outputs.cache-hit != 'true'
+ id: save-simple-responder
+ uses: actions/cache/save@v4
+ with:
+ path: ./build
+ key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- - name: Test all
- run: make test-all
+ - name: Test all
+ run: make test-all
diff --git a/README.md b/README.md
index ee091aa3..7a04ff16 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
# llama-swap
-Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
+Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
@@ -45,14 +45,14 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `/health` - just returns "OK"
- ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable
- - Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
+ - Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
- Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
### Web UI
-llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
+llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
@@ -64,16 +64,14 @@ Inspect request and responses:
-Manually load and unload models:
+Manually load and unload models:
-
-Real time log streaming:
+Real time log streaming:
-
## Installation
llama-swap can be installed in multiple ways
diff --git a/config-schema.json b/config-schema.json
index f3f31f20..36161d9a 100644
--- a/config-schema.json
+++ b/config-schema.json
@@ -325,6 +325,44 @@
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
},
+ "matrix": {
+ "type": "object",
+ "description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
+ "required": [
+ "vars",
+ "sets"
+ ],
+ "properties": {
+ "vars": {
+ "type": "object",
+ "description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
+ "minProperties": 1,
+ "additionalProperties": {
+ "type": "string"
+ },
+ "propertyNames": {
+ "pattern": "^[a-zA-Z0-9]{1,8}$"
+ }
+ },
+ "evict_costs": {
+ "type": "object",
+ "description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
+ "additionalProperties": {
+ "type": "integer",
+ "minimum": 1
+ }
+ },
+ "sets": {
+ "type": "object",
+ "description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
+ "minProperties": 1,
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false
+ },
"hooks": {
"type": "object",
"properties": {
@@ -456,5 +494,27 @@
"default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
}
- }
+ },
+ "allOf": [
+ {
+ "if": {
+ "required": ["groups"]
+ },
+ "then": {
+ "not": {
+ "required": ["matrix"]
+ }
+ }
+ },
+ {
+ "if": {
+ "required": ["matrix"]
+ },
+ "then": {
+ "not": {
+ "required": ["groups"]
+ }
+ }
+ }
+ ]
}
\ No newline at end of file
diff --git a/config.example.yaml b/config.example.yaml
index 550573b3..5569dd33 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -331,68 +331,83 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
-# groups: a dictionary of group settings
-# - optional, default: empty dictionary
-# - provides advanced controls over model swapping behaviour
-# - using groups some models can be kept loaded indefinitely, while others are swapped out
-# - model IDs must be defined in the Models section
-# - a model can only be a member of one group
-# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
-# - see issue #109 for details
+# =============================================================================
+# matrix: run concurrent models with a solver-based swap DSL
+# =============================================================================
#
-# NOTE: the example below uses model names that are not defined above for demonstration purposes
-groups:
- # group1 works the same as the default behaviour of llama-swap where only one model is allowed
- # to run a time across the whole llama-swap instance
- "group1":
- # swap: controls the model swapping behaviour in within the group
- # - optional, default: true
- # - true : only one model is allowed to run at a time
- # - false: all models can run together, no swapping
- swap: true
+# Note:
+# A config must use either a matrix or legacy groups, not both. A configuration error
+# will occur if both are defined. Configuration examples for legacy Groups can be found:
+# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
+#
+# The matrix declares valid combinations of models that can run concurrently.
+# When a model is requested, the solver finds the cheapest way to make it
+# available by evicting as few (and least costly) running models as possible.
+#
+# Solver behavior:
+# 1. Request arrives for model X
+# 2. If X is already running, forward immediately. Done.
+# 3. Find all sets containing X
+# 4. For each candidate set, compute cost: sum of evict_costs for
+# every running model NOT in that set
+# 5. Pick lowest cost candidate. Ties broken by definition order.
+# 6. Evict what needs to stop. Start X. Forward request.
+#
+# Subset semantics: a set [a, b, c] means any subset is valid.
+# Only the requested model is started — others are not preloaded.
+#
+# A model not appearing in any set can only run alone.
+#
+matrix:
+ # vars: short names for models (alphanumeric, 1-8 chars)
+ # - required for sets and evict_costs settings
+ # - each entry is a short name to a real model ID. Do not use an alias
+ # - used to keep set DSL logic short and easier to read
+ # - sets and evict_costs only use identifiers defined in vars
+ vars:
+ g: gemma-model
+ q: qwen-model
+ m: mistral-model
+ v: voxtral-model
+ e: reranker-model
+ L: llama-70B
+ sd: stable-diffusion
- # exclusive: controls how the group affects other groups
- # - optional, default: true
- # - true: causes all other groups to unload when this group runs a model
- # - false: does not affect other groups
- exclusive: true
+ # evict_costs: relative cost of losing a running model (default: 1)
+ evict_costs:
+ v: 50 # vllm backend, slow cold start
+ L: 30 # 70B weights, slow to load
- # members references the models defined above
- # required
- members:
- - "llama"
- - "qwen-unlisted"
+ # sets: named sets of concurrent model combinations
+ # Values are DSL strings with operators:
+ # & AND (models run together)
+ # | OR (alternatives)
+ # () grouping
+ # +ref inline another set's expression
+ #
+ # Expansion examples:
+ # "L" → [L]
+ # "a & b" → [a, b]
+ # "a | b" → [a], [b]
+ # "(a | b) & c" → [a, c], [b, c]
+ # "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
+ # "+llms & v" → expands llms inline, then applies & v
+ sets:
+ # LLM + TTS: switching between g/q/m won't evict v
+ # expands to: [g,v], [q,v], [m,v]
+ standard: "(g | q | m) & v"
- # Example:
- # - in group2 all models can run at the same time
- # - when a different group is loaded it causes all running models in this group to unload
- "group2":
- swap: false
+ # LLM + TTS + reranker
+ # expands to: [g,v,e], [q,v,e]
+ with_rerank: "(g | q) & v & e"
- # exclusive: false does not unload other groups when a model in group2 is requested
- # - the models in group2 will be loaded but will not unload any other groups
- exclusive: false
- members:
- - "docker-llama"
- - "modelA"
- - "modelB"
+ # LLM + image generation, no TTS
+ # expands to: [g,sd], [q,sd]
+ creative: "(g | q) & sd"
- # Example:
- # - a persistent group, prevents other groups from unloading it
- "forever":
- # persistent: prevents over groups from unloading the models in this group
- # - optional, default: false
- # - does not affect individual model behaviour
- persistent: true
-
- # set swap/exclusive to false to prevent swapping inside the group
- # and the unloading of other groups
- swap: false
- exclusive: false
- members:
- - "forever-modelA"
- - "forever-modelB"
- - "forever-modelc"
+ # 70B model uses all GPUs, can only run alone
+ # expands to: [L]
+ full: "L"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
diff --git a/proxy/config/config.go b/proxy/config/config.go
index 00f44970..a4c4e1fc 100644
--- a/proxy/config/config.go
+++ b/proxy/config/config.go
@@ -129,6 +129,12 @@ type Config struct {
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
+ // swap matrix: solver-based alternative to groups
+ Matrix *MatrixConfig `yaml:"matrix"`
+
+ // populated during validation when matrix is configured
+ ExpandedSets []ExpandedSet `yaml:"-"`
+
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"`
@@ -438,22 +444,35 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.Models[modelId] = modelConfig
}
- config = AddDefaultGroupToConfig(config)
+ // groups XOR matrix
+ if config.Matrix != nil && len(config.Groups) > 0 {
+ return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
+ }
- // Validate group members
- memberUsage := make(map[string]string)
- for groupID, groupConfig := range config.Groups {
- prevSet := make(map[string]bool)
- for _, member := range groupConfig.Members {
- if _, found := prevSet[member]; found {
- return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
- }
- prevSet[member] = true
+ if config.Matrix != nil {
+ expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
+ if err != nil {
+ return Config{}, fmt.Errorf("matrix: %w", err)
+ }
+ config.ExpandedSets = expandedSets
+ } else {
+ config = AddDefaultGroupToConfig(config)
- if existingGroup, exists := memberUsage[member]; exists {
- return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
+ // Validate group members
+ memberUsage := make(map[string]string)
+ for groupID, groupConfig := range config.Groups {
+ prevSet := make(map[string]bool)
+ for _, member := range groupConfig.Members {
+ if _, found := prevSet[member]; found {
+ return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
+ }
+ prevSet[member] = true
+
+ if existingGroup, exists := memberUsage[member]; exists {
+ return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
+ }
+ memberUsage[member] = groupID
}
- memberUsage[member] = groupID
}
}
diff --git a/proxy/config/matrix.go b/proxy/config/matrix.go
new file mode 100644
index 00000000..7ac9e046
--- /dev/null
+++ b/proxy/config/matrix.go
@@ -0,0 +1,226 @@
+package config
+
+import (
+ "fmt"
+ "regexp"
+ "sort"
+
+ "gopkg.in/yaml.v3"
+)
+
+var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
+
+// MatrixConfig represents the swap matrix configuration block.
+type MatrixConfig struct {
+ Var map[string]string `yaml:"vars"`
+ EvictCosts map[string]int `yaml:"evict_costs"`
+ Sets OrderedSets `yaml:"sets"`
+}
+
+// SetEntry is a single named set with its DSL expression.
+type SetEntry struct {
+ Name string
+ DSL string
+}
+
+// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
+type OrderedSets []SetEntry
+
+func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
+ if value.Kind != yaml.MappingNode {
+ return fmt.Errorf("sets must be a mapping")
+ }
+
+ entries := make([]SetEntry, 0, len(value.Content)/2)
+ for i := 0; i < len(value.Content); i += 2 {
+ keyNode := value.Content[i]
+ valueNode := value.Content[i+1]
+
+ var name string
+ if err := keyNode.Decode(&name); err != nil {
+ return fmt.Errorf("failed to decode set name: %w", err)
+ }
+
+ var dsl string
+ if err := valueNode.Decode(&dsl); err != nil {
+ return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
+ }
+
+ entries = append(entries, SetEntry{Name: name, DSL: dsl})
+ }
+
+ *os = entries
+ return nil
+}
+
+// ExpandedSet is one valid combination of concurrent models (real model names).
+type ExpandedSet struct {
+ SetName string
+ DSL string
+ Models []string // real model names, sorted
+}
+
+// ValidateMatrix validates the matrix config and returns all expanded sets.
+func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
+ if len(matrix.Sets) == 0 {
+ return nil, fmt.Errorf("matrix must define at least one set")
+ }
+
+ if len(matrix.Var) == 0 {
+ return nil, fmt.Errorf("matrix must define at least one var")
+ }
+
+ // Validate var entries
+ if matrix.Var != nil {
+ for id, modelName := range matrix.Var {
+ if !varKeyPattern.MatchString(id) {
+ return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
+ }
+ if _, exists := models[modelName]; !exists {
+ return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
+ }
+ }
+ }
+
+ // Validate evict_costs
+ if matrix.EvictCosts != nil {
+ for key, cost := range matrix.EvictCosts {
+ if cost <= 0 {
+ return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
+ }
+ if _, ok := matrix.Var[key]; !ok {
+ return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
+ }
+ }
+ }
+
+ // Build dependency graph for +ref topological sort
+ setNames := make(map[string]bool)
+ for _, entry := range matrix.Sets {
+ setNames[entry.Name] = true
+ }
+
+ deps := make(map[string][]string) // setName -> set names it depends on
+ for _, entry := range matrix.Sets {
+ refs, err := extractRefs(entry.DSL)
+ if err != nil {
+ return nil, fmt.Errorf("set %q: %w", entry.Name, err)
+ }
+ for _, ref := range refs {
+ if !setNames[ref] {
+ return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
+ }
+ }
+ deps[entry.Name] = refs
+ }
+
+ // Topological sort with cycle detection
+ order, err := topologicalSort(matrix.Sets, deps)
+ if err != nil {
+ return nil, err
+ }
+
+ // Expand sets in topological order
+ resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
+ var allExpanded []ExpandedSet
+ totalCombinations := 0
+
+ // Build ordered map for efficient lookup
+ setDSL := make(map[string]string)
+ for _, entry := range matrix.Sets {
+ setDSL[entry.Name] = entry.DSL
+ }
+
+ for _, name := range order {
+ dsl := setDSL[name]
+ combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
+ if err != nil {
+ return nil, fmt.Errorf("set %q: %w", name, err)
+ }
+
+ resolvedRefs[name] = combos
+
+ // Resolve var IDs to real model names
+ for _, combo := range combos {
+ resolved := make([]string, len(combo))
+ for i, ident := range combo {
+ realName, ok := matrix.Var[ident]
+ if !ok {
+ return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
+ }
+ resolved[i] = realName
+ }
+ sort.Strings(resolved)
+ allExpanded = append(allExpanded, ExpandedSet{
+ SetName: name,
+ DSL: dsl,
+ Models: resolved,
+ })
+ }
+
+ totalCombinations += len(combos)
+ if totalCombinations > maxDSLExpansions {
+ return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
+ }
+ }
+
+ return allExpanded, nil
+}
+
+// topologicalSort returns set names in dependency order.
+// Returns an error if a cycle is detected.
+func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
+ // States: 0 = unvisited, 1 = visiting, 2 = visited
+ state := make(map[string]int)
+ var order []string
+
+ var visit func(name string) error
+ visit = func(name string) error {
+ switch state[name] {
+ case 1:
+ return fmt.Errorf("circular reference detected involving set %q", name)
+ case 2:
+ return nil
+ }
+ state[name] = 1
+
+ for _, dep := range deps[name] {
+ if err := visit(dep); err != nil {
+ return err
+ }
+ }
+
+ state[name] = 2
+ order = append(order, name)
+ return nil
+ }
+
+ // Visit in definition order for deterministic output
+ for _, entry := range sets {
+ if state[entry.Name] == 0 {
+ if err := visit(entry.Name); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ return order, nil
+}
+
+// ResolvedEvictCosts returns a map of real model name -> evict cost,
+// resolving var IDs. Models not listed default to 1.
+func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
+ costs := make(map[string]int)
+ if m.EvictCosts == nil {
+ return costs
+ }
+ for key, cost := range m.EvictCosts {
+ // Resolve var ID if present
+ if realName, ok := m.Var[key]; ok {
+ costs[realName] = cost
+ } else {
+ costs[key] = cost
+ }
+ }
+ return costs
+}
diff --git a/proxy/config/matrix_dsl.go b/proxy/config/matrix_dsl.go
new file mode 100644
index 00000000..2b5097b6
--- /dev/null
+++ b/proxy/config/matrix_dsl.go
@@ -0,0 +1,376 @@
+package config
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+ "unicode"
+)
+
+const maxDSLExpansions = 1000
+
+// Token types for the DSL lexer
+type tokenType int
+
+const (
+ tokIdent tokenType = iota // model alias or name
+ tokAnd // &
+ tokOr // |
+ tokLParen // (
+ tokRParen // )
+ tokRef // +setName
+ tokEOF
+)
+
+type token struct {
+ typ tokenType
+ val string
+}
+
+// tokenize splits a DSL string into tokens.
+func tokenize(input string) ([]token, error) {
+ var tokens []token
+ i := 0
+ runes := []rune(input)
+
+ for i < len(runes) {
+ ch := runes[i]
+
+ // skip whitespace
+ if unicode.IsSpace(ch) {
+ i++
+ continue
+ }
+
+ switch ch {
+ case '&':
+ tokens = append(tokens, token{tokAnd, "&"})
+ i++
+ case '|':
+ tokens = append(tokens, token{tokOr, "|"})
+ i++
+ case '(':
+ tokens = append(tokens, token{tokLParen, "("})
+ i++
+ case ')':
+ tokens = append(tokens, token{tokRParen, ")"})
+ i++
+ case '+':
+ // +ref: read the identifier that follows
+ i++
+ start := i
+ for i < len(runes) && isIdentChar(runes[i]) {
+ i++
+ }
+ if i == start {
+ return nil, fmt.Errorf("expected set name after '+' at position %d", start)
+ }
+ tokens = append(tokens, token{tokRef, string(runes[start:i])})
+ default:
+ if isIdentChar(ch) {
+ start := i
+ for i < len(runes) && isIdentChar(runes[i]) {
+ i++
+ }
+ tokens = append(tokens, token{tokIdent, string(runes[start:i])})
+ } else {
+ return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
+ }
+ }
+ }
+
+ tokens = append(tokens, token{tokEOF, ""})
+ return tokens, nil
+}
+
+func isIdentChar(ch rune) bool {
+ return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
+}
+
+// AST node types
+type dslNode interface {
+ dslNode()
+}
+
+type andNode struct {
+ children []dslNode
+}
+
+type orNode struct {
+ children []dslNode
+}
+
+type leafNode struct {
+ name string
+}
+
+type refNode struct {
+ setName string
+}
+
+func (andNode) dslNode() {}
+func (orNode) dslNode() {}
+func (leafNode) dslNode() {}
+func (refNode) dslNode() {}
+
+// parser holds state for recursive-descent parsing.
+type parser struct {
+ tokens []token
+ pos int
+}
+
+func (p *parser) peek() token {
+ if p.pos < len(p.tokens) {
+ return p.tokens[p.pos]
+ }
+ return token{tokEOF, ""}
+}
+
+func (p *parser) next() token {
+ t := p.peek()
+ if t.typ != tokEOF {
+ p.pos++
+ }
+ return t
+}
+
+func (p *parser) expect(typ tokenType) (token, error) {
+ t := p.next()
+ if t.typ != typ {
+ return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
+ }
+ return t, nil
+}
+
+// Grammar:
+//
+// expr = andExpr
+// andExpr = orExpr ('&' orExpr)*
+// orExpr = atom ('|' atom)*
+// atom = ident | '+' ident | '(' expr ')'
+//
+// & binds tighter than |, so "a | b & c" means "a | (b & c)"
+func parse(tokens []token) (dslNode, error) {
+ p := &parser{tokens: tokens}
+ node, err := p.parseExpr()
+ if err != nil {
+ return nil, err
+ }
+ if p.peek().typ != tokEOF {
+ return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
+ }
+ return node, nil
+}
+
+func (p *parser) parseExpr() (dslNode, error) {
+ return p.parseOrExpr()
+}
+
+func (p *parser) parseOrExpr() (dslNode, error) {
+ left, err := p.parseAndExpr()
+ if err != nil {
+ return nil, err
+ }
+
+ if p.peek().typ == tokOr {
+ children := []dslNode{left}
+ for p.peek().typ == tokOr {
+ p.next() // consume |
+ right, err := p.parseAndExpr()
+ if err != nil {
+ return nil, err
+ }
+ children = append(children, right)
+ }
+ return orNode{children: children}, nil
+ }
+
+ return left, nil
+}
+
+func (p *parser) parseAndExpr() (dslNode, error) {
+ left, err := p.parseAtom()
+ if err != nil {
+ return nil, err
+ }
+
+ if p.peek().typ == tokAnd {
+ children := []dslNode{left}
+ for p.peek().typ == tokAnd {
+ p.next() // consume &
+ right, err := p.parseAtom()
+ if err != nil {
+ return nil, err
+ }
+ children = append(children, right)
+ }
+ return andNode{children: children}, nil
+ }
+
+ return left, nil
+}
+
+func (p *parser) parseAtom() (dslNode, error) {
+ t := p.peek()
+
+ switch t.typ {
+ case tokIdent:
+ p.next()
+ return leafNode{name: t.val}, nil
+
+ case tokRef:
+ p.next()
+ return refNode{setName: t.val}, nil
+
+ case tokLParen:
+ p.next() // consume (
+ node, err := p.parseExpr()
+ if err != nil {
+ return nil, err
+ }
+ if _, err := p.expect(tokRParen); err != nil {
+ return nil, fmt.Errorf("missing closing parenthesis")
+ }
+ return node, nil
+
+ default:
+ return nil, fmt.Errorf("unexpected token %q", t.val)
+ }
+}
+
+// expand walks the AST and produces all combinations.
+// resolvedRefs contains previously expanded sets for +ref resolution.
+func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
+ switch n := node.(type) {
+ case leafNode:
+ return [][]string{{n.name}}, nil
+
+ case refNode:
+ expanded, ok := resolvedRefs[n.setName]
+ if !ok {
+ return nil, fmt.Errorf("unknown set reference +%s", n.setName)
+ }
+ // Return a copy
+ result := make([][]string, len(expanded))
+ for i, combo := range expanded {
+ result[i] = make([]string, len(combo))
+ copy(result[i], combo)
+ }
+ return result, nil
+
+ case orNode:
+ // Union of all children's expansions
+ var result [][]string
+ for _, child := range n.children {
+ childResult, err := expand(child, resolvedRefs)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, childResult...)
+ if len(result) > maxDSLExpansions {
+ return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
+ }
+ }
+ return result, nil
+
+ case andNode:
+ // Cartesian product across children
+ result := [][]string{{}} // start with one empty combo
+ for _, child := range n.children {
+ childResult, err := expand(child, resolvedRefs)
+ if err != nil {
+ return nil, err
+ }
+ result, err = cartesianProduct(result, childResult, maxDSLExpansions)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return result, nil
+
+ default:
+ return nil, fmt.Errorf("unknown node type %T", node)
+ }
+}
+
+// cartesianProduct computes the cartesian product of two sets of combinations.
+// It returns an error if the product would exceed cap.
+func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
+ if int64(len(left))*int64(len(right)) > int64(cap) {
+ return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
+ }
+ result := make([][]string, 0, len(left)*len(right))
+ for _, l := range left {
+ for _, r := range right {
+ combo := make([]string, 0, len(l)+len(r))
+ combo = append(combo, l...)
+ combo = append(combo, r...)
+ result = append(result, combo)
+ }
+ }
+ return result, nil
+}
+
+// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
+// resolvedRefs contains previously expanded sets for +ref inlining.
+func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
+ dsl = strings.TrimSpace(dsl)
+ if dsl == "" {
+ return nil, fmt.Errorf("empty DSL expression")
+ }
+
+ tokens, err := tokenize(dsl)
+ if err != nil {
+ return nil, fmt.Errorf("tokenize: %w", err)
+ }
+
+ tree, err := parse(tokens)
+ if err != nil {
+ return nil, fmt.Errorf("parse: %w", err)
+ }
+
+ result, err := expand(tree, resolvedRefs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Deduplicate models within each combination and sort for consistency
+ for i, combo := range result {
+ result[i] = dedupAndSort(combo)
+ }
+
+ return result, nil
+}
+
+// dedupAndSort removes duplicate entries and sorts alphabetically.
+func dedupAndSort(items []string) []string {
+ seen := make(map[string]bool, len(items))
+ var unique []string
+ for _, item := range items {
+ if !seen[item] {
+ seen[item] = true
+ unique = append(unique, item)
+ }
+ }
+ sort.Strings(unique)
+ return unique
+}
+
+// extractRefs scans a DSL string for +ref tokens without full parsing.
+// Used for building the dependency graph for topological sorting.
+func extractRefs(dsl string) ([]string, error) {
+ tokens, err := tokenize(dsl)
+ if err != nil {
+ return nil, err
+ }
+
+ var refs []string
+ seen := make(map[string]bool)
+ for _, t := range tokens {
+ if t.typ == tokRef && !seen[t.val] {
+ seen[t.val] = true
+ refs = append(refs, t.val)
+ }
+ }
+ return refs, nil
+}
diff --git a/proxy/config/matrix_dsl_test.go b/proxy/config/matrix_dsl_test.go
new file mode 100644
index 00000000..96a6454b
--- /dev/null
+++ b/proxy/config/matrix_dsl_test.go
@@ -0,0 +1,300 @@
+package config
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDSL_Tokenize(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expect []token
+ errMsg string
+ }{
+ {
+ name: "single identifier",
+ input: "abc",
+ expect: []token{
+ {tokIdent, "abc"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "identifier with hyphens and dots",
+ input: "model-name.v2",
+ expect: []token{
+ {tokIdent, "model-name.v2"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "and expression",
+ input: "a & b",
+ expect: []token{
+ {tokIdent, "a"},
+ {tokAnd, "&"},
+ {tokIdent, "b"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "or expression",
+ input: "a | b",
+ expect: []token{
+ {tokIdent, "a"},
+ {tokOr, "|"},
+ {tokIdent, "b"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "parentheses",
+ input: "(a | b) & c",
+ expect: []token{
+ {tokLParen, "("},
+ {tokIdent, "a"},
+ {tokOr, "|"},
+ {tokIdent, "b"},
+ {tokRParen, ")"},
+ {tokAnd, "&"},
+ {tokIdent, "c"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "ref token",
+ input: "+llms & v",
+ expect: []token{
+ {tokRef, "llms"},
+ {tokAnd, "&"},
+ {tokIdent, "v"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "no whitespace",
+ input: "(a|b)&c",
+ expect: []token{
+ {tokLParen, "("},
+ {tokIdent, "a"},
+ {tokOr, "|"},
+ {tokIdent, "b"},
+ {tokRParen, ")"},
+ {tokAnd, "&"},
+ {tokIdent, "c"},
+ {tokEOF, ""},
+ },
+ },
+ {
+ name: "empty ref",
+ input: "+",
+ errMsg: "expected set name after '+'",
+ },
+ {
+ name: "invalid character",
+ input: "a @ b",
+ errMsg: "unexpected character",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tokens, err := tokenize(tt.input)
+ if tt.errMsg != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMsg)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tt.expect, tokens)
+ }
+ })
+ }
+}
+
+func TestDSL_ParseAndExpand(t *testing.T) {
+ tests := []struct {
+ name string
+ dsl string
+ refs map[string][][]string
+ expect [][]string
+ errMsg string
+ }{
+ {
+ name: "single model",
+ dsl: "L",
+ expect: [][]string{{"L"}},
+ },
+ {
+ name: "two models with AND",
+ dsl: "a & b",
+ expect: [][]string{{"a", "b"}},
+ },
+ {
+ name: "two models with OR",
+ dsl: "a | b",
+ expect: [][]string{{"a"}, {"b"}},
+ },
+ {
+ name: "three models with OR",
+ dsl: "a | b | c",
+ expect: [][]string{{"a"}, {"b"}, {"c"}},
+ },
+ {
+ name: "cartesian product (a|b) & (c|d)",
+ dsl: "(a | b) & (c | d)",
+ expect: [][]string{
+ {"a", "c"},
+ {"a", "d"},
+ {"b", "c"},
+ {"b", "d"},
+ },
+ },
+ {
+ name: "three-way AND",
+ dsl: "a & b & c",
+ expect: [][]string{
+ {"a", "b", "c"},
+ },
+ },
+ {
+ name: "(g | q | m) & v",
+ dsl: "(g | q | m) & v",
+ expect: [][]string{
+ {"g", "v"},
+ {"q", "v"},
+ {"m", "v"},
+ },
+ },
+ {
+ name: "(g | q) & v & e",
+ dsl: "(g | q) & v & e",
+ expect: [][]string{
+ {"e", "g", "v"},
+ {"e", "q", "v"},
+ },
+ },
+ {
+ name: "precedence: a | b & c means a | (b & c)",
+ dsl: "a | b & c",
+ expect: [][]string{
+ {"a"},
+ {"b", "c"},
+ },
+ },
+ {
+ name: "+ref inlining",
+ dsl: "+llms & v",
+ refs: map[string][][]string{
+ "llms": {{"g"}, {"q"}, {"m"}},
+ },
+ expect: [][]string{
+ {"g", "v"},
+ {"q", "v"},
+ {"m", "v"},
+ },
+ },
+ {
+ name: "+ref chained",
+ dsl: "+with_tts & e",
+ refs: map[string][][]string{
+ "with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
+ },
+ expect: [][]string{
+ {"e", "g", "v"},
+ {"e", "q", "v"},
+ {"e", "m", "v"},
+ },
+ },
+ {
+ name: "dedup within combination",
+ dsl: "a & a",
+ expect: [][]string{
+ {"a"},
+ },
+ },
+ {
+ name: "empty expression",
+ dsl: "",
+ errMsg: "empty DSL expression",
+ },
+ {
+ name: "unmatched open paren",
+ dsl: "(a | b",
+ errMsg: "missing closing parenthesis",
+ },
+ {
+ name: "unmatched close paren",
+ dsl: "a | b)",
+ errMsg: "unexpected token",
+ },
+ {
+ name: "unknown ref",
+ dsl: "+unknown",
+ errMsg: "unknown set reference +unknown",
+ },
+ {
+ name: "empty parens",
+ dsl: "()",
+ errMsg: "unexpected token",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ refs := tt.refs
+ if refs == nil {
+ refs = map[string][][]string{}
+ }
+ result, err := ParseAndExpandDSL(tt.dsl, refs)
+ if tt.errMsg != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMsg)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tt.expect, result)
+ }
+ })
+ }
+}
+
+func TestDSL_ExpansionCap(t *testing.T) {
+ // Build an expression that would exceed 1000 combinations:
+ // (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
+ var aItems, bItems []string
+ for i := 0; i < 32; i++ {
+ aItems = append(aItems, fmt.Sprintf("a%d", i))
+ bItems = append(bItems, fmt.Sprintf("b%d", i))
+ }
+ dsl := fmt.Sprintf("(%s) & (%s)",
+ join(aItems, " | "),
+ join(bItems, " | "),
+ )
+ _, err := ParseAndExpandDSL(dsl, map[string][][]string{})
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "exceeded")
+}
+
+func TestDSL_ExtractRefs(t *testing.T) {
+ refs, err := extractRefs("+llms & v & +other")
+ require.NoError(t, err)
+ assert.Equal(t, []string{"llms", "other"}, refs)
+
+ refs, err = extractRefs("a & b")
+ require.NoError(t, err)
+ assert.Empty(t, refs)
+}
+
+func join(items []string, sep string) string {
+ result := ""
+ for i, item := range items {
+ if i > 0 {
+ result += sep
+ }
+ result += item
+ }
+ return result
+}
diff --git a/proxy/config/matrix_test.go b/proxy/config/matrix_test.go
new file mode 100644
index 00000000..6d497b7d
--- /dev/null
+++ b/proxy/config/matrix_test.go
@@ -0,0 +1,305 @@
+package config
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func makeModels(names ...string) map[string]ModelConfig {
+ m := make(map[string]ModelConfig)
+ for _, name := range names {
+ m[name] = ModelConfig{Cmd: "echo " + name}
+ }
+ return m
+}
+
+func TestValidateMatrix_Basic(t *testing.T) {
+ models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{
+ "g": "gemma",
+ "q": "qwen",
+ "m": "mistral",
+ "v": "voxtral",
+ "L": "llama70B",
+ },
+ EvictCosts: map[string]int{
+ "L": 30,
+ "v": 50,
+ },
+ Sets: OrderedSets{
+ {Name: "standard", DSL: "(g | q | m) & v"},
+ {Name: "full", DSL: "L"},
+ },
+ }
+
+ expanded, err := ValidateMatrix(matrix, models)
+ require.NoError(t, err)
+
+ // standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
+ // full expands to [llama70B]
+ assert.Len(t, expanded, 4)
+
+ assert.Equal(t, "standard", expanded[0].SetName)
+ assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
+
+ assert.Equal(t, "standard", expanded[1].SetName)
+ assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
+
+ assert.Equal(t, "standard", expanded[2].SetName)
+ assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
+
+ assert.Equal(t, "full", expanded[3].SetName)
+ assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
+}
+
+func TestValidateMatrix_WithRef(t *testing.T) {
+ models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{
+ "g": "gemma",
+ "q": "qwen",
+ "m": "mistral",
+ "v": "voxtral",
+ "e": "reranker",
+ },
+ Sets: OrderedSets{
+ {Name: "llms", DSL: "g | q | m"},
+ {Name: "with_tts", DSL: "+llms & v"},
+ {Name: "mega", DSL: "+with_tts & e"},
+ },
+ }
+
+ expanded, err := ValidateMatrix(matrix, models)
+ require.NoError(t, err)
+
+ // llms: [gemma], [qwen], [mistral]
+ // with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
+ // mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
+ assert.Len(t, expanded, 9)
+
+ // Check mega entries
+ megaEntries := filterBySetName(expanded, "mega")
+ assert.Len(t, megaEntries, 3)
+ assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
+}
+
+func TestValidateMatrix_MapIDRequired(t *testing.T) {
+ // DSL cannot use real model names directly — must use var IDs
+ models := makeModels("gemma", "voxtral")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ Sets: OrderedSets{
+ {Name: "combo", DSL: "g & voxtral"},
+ },
+ }
+
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unknown var ID")
+}
+
+func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
+ models := makeModels("gemma")
+
+ tests := []struct {
+ name string
+ alias string
+ errMsg string
+ }{
+ {"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
+ {"has underscore", "a_b", "alphanumeric and 1-8 characters"},
+ {"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ matrix := MatrixConfig{
+ Var: map[string]string{tt.alias: "gemma"},
+ Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
+ }
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMsg)
+ })
+ }
+}
+
+func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
+ models := makeModels("gemma")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{"x": "nonexistent"},
+ Sets: OrderedSets{{Name: "s", DSL: "x"}},
+ }
+
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unknown model")
+}
+
+func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
+ models := makeModels("gemma")
+
+ t.Run("zero cost", func(t *testing.T) {
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ EvictCosts: map[string]int{"g": 0},
+ Sets: OrderedSets{{Name: "s", DSL: "g"}},
+ }
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "positive integer")
+ })
+
+ t.Run("negative cost", func(t *testing.T) {
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ EvictCosts: map[string]int{"g": -1},
+ Sets: OrderedSets{{Name: "s", DSL: "g"}},
+ }
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "positive integer")
+ })
+
+ t.Run("unknown var ID in evict_costs", func(t *testing.T) {
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ EvictCosts: map[string]int{"unknown": 5},
+ Sets: OrderedSets{{Name: "s", DSL: "g"}},
+ }
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unknown var ID")
+ })
+}
+
+func TestValidateMatrix_CycleDetection(t *testing.T) {
+ models := makeModels("gemma")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ Sets: OrderedSets{
+ {Name: "a", DSL: "+b"},
+ {Name: "b", DSL: "+a"},
+ },
+ }
+
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "circular reference")
+}
+
+func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
+ models := makeModels("gemma")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ Sets: OrderedSets{
+ {Name: "a", DSL: "+nonexistent"},
+ },
+ }
+
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "references undefined set")
+}
+
+func TestValidateMatrix_NoSets(t *testing.T) {
+ _, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "at least one set")
+}
+
+func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
+ models := makeModels("gemma")
+
+ matrix := MatrixConfig{
+ Var: map[string]string{"g": "gemma"},
+ Sets: OrderedSets{
+ {Name: "s", DSL: "g & nonexistent"},
+ },
+ }
+
+ _, err := ValidateMatrix(matrix, models)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "unknown var ID")
+}
+
+func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
+ mc := &MatrixConfig{
+ Var: map[string]string{
+ "g": "gemma",
+ "L": "llama70B",
+ },
+ EvictCosts: map[string]int{
+ "L": 30,
+ "g": 5,
+ },
+ }
+
+ costs := mc.ResolvedEvictCosts()
+ assert.Equal(t, 30, costs["llama70B"])
+ assert.Equal(t, 5, costs["gemma"])
+}
+
+func TestValidateMatrix_ConfigXOR(t *testing.T) {
+ // groups and matrix both defined
+ yaml := `
+models:
+ model1:
+ cmd: echo model1
+ proxy: http://localhost:8080
+groups:
+ group1:
+ members:
+ - model1
+matrix:
+ sets:
+ s: "model1"
+`
+ _, err := LoadConfigFromReader(strings.NewReader(yaml))
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "cannot use both")
+}
+
+func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
+ yaml := `
+models:
+ gemma:
+ cmd: echo gemma
+ proxy: http://localhost:8080
+ qwen:
+ cmd: echo qwen
+ proxy: http://localhost:8081
+matrix:
+ vars:
+ g: gemma
+ q: qwen
+ sets:
+ combo: "g | q"
+`
+ cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
+ require.NoError(t, err)
+ assert.NotNil(t, cfg.Matrix)
+ assert.Len(t, cfg.ExpandedSets, 2)
+ // Groups should be empty when matrix is used
+ assert.Empty(t, cfg.Groups)
+}
+
+func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
+ var result []ExpandedSet
+ for _, s := range sets {
+ if s.SetName == name {
+ result = append(result, s)
+ }
+ }
+ return result
+}
diff --git a/proxy/matrix.go b/proxy/matrix.go
new file mode 100644
index 00000000..f3f7227a
--- /dev/null
+++ b/proxy/matrix.go
@@ -0,0 +1,298 @@
+package proxy
+
+import (
+ "fmt"
+ "net/http"
+ "slices"
+ "sort"
+ "sync"
+
+ "github.com/mostlygeek/llama-swap/proxy/config"
+)
+
+// MatrixSolver contains pure swap-decision logic with no Process dependencies.
+// It is safe for concurrent reads after construction.
+type MatrixSolver struct {
+ expandedSets []config.ExpandedSet // all valid model combinations
+ evictCosts map[string]int // real model name -> eviction cost (default 1)
+ modelToSets map[string][]int // model name -> indices into expandedSets
+}
+
+// NewMatrixSolver builds a solver from expanded sets and eviction costs.
+func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
+ modelToSets := make(map[string][]int)
+ for i, es := range expandedSets {
+ for _, model := range es.Models {
+ modelToSets[model] = append(modelToSets[model], i)
+ }
+ }
+
+ return &MatrixSolver{
+ expandedSets: expandedSets,
+ evictCosts: evictCosts,
+ modelToSets: modelToSets,
+ }
+}
+
+// SolveResult describes what the solver decided.
+type SolveResult struct {
+ Evict []string // running models that must be stopped
+ TargetSet []string // the chosen set of models (for informational purposes)
+ SetName string // name of the chosen set
+ DSL string // original DSL expression for the chosen set
+ TotalCost int // total eviction cost
+}
+
+// Solve determines which models to evict when a model is requested.
+//
+// Algorithm:
+// 1. If requestedModel is already running, no eviction needed.
+// 2. Find all sets containing requestedModel.
+// 3. If no sets found, the model runs alone; evict all running models.
+// 4. For each candidate set, compute cost = sum of evict_costs for running
+// models NOT in that set.
+// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
+// 6. Return models to evict and the chosen set.
+func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
+ // If already running, nothing to do (but fill in set info for logging)
+ if slices.Contains(runningModels, requestedModel) {
+ setName, dsl := s.findMatchingSet(requestedModel, runningModels)
+ return SolveResult{
+ TargetSet: runningModels,
+ SetName: setName,
+ DSL: dsl,
+ }, nil
+ }
+
+ candidateIndices := s.modelToSets[requestedModel]
+
+ // Model not in any set: runs alone, evict everything
+ if len(candidateIndices) == 0 {
+ evict := make([]string, len(runningModels))
+ copy(evict, runningModels)
+ return SolveResult{
+ Evict: evict,
+ TargetSet: []string{requestedModel},
+ }, nil
+ }
+
+ // Find the cheapest candidate set
+ bestCost := -1
+ bestIdx := -1
+
+ for _, idx := range candidateIndices {
+ setModels := s.expandedSets[idx].Models
+ cost := 0
+ for _, running := range runningModels {
+ if !slices.Contains(setModels, running) {
+ cost += s.evictCost(running)
+ }
+ }
+
+ if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
+ bestCost = cost
+ bestIdx = idx
+ }
+ }
+
+ // Determine which running models to evict
+ chosen := s.expandedSets[bestIdx]
+ var evict []string
+ for _, running := range runningModels {
+ if !slices.Contains(chosen.Models, running) {
+ evict = append(evict, running)
+ }
+ }
+
+ return SolveResult{
+ Evict: evict,
+ TargetSet: chosen.Models,
+ SetName: chosen.SetName,
+ DSL: chosen.DSL,
+ TotalCost: bestCost,
+ }, nil
+}
+
+// findMatchingSet finds the expanded set that contains all running models.
+// Returns the set name and DSL, or empty strings if no match.
+func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
+ for _, idx := range s.modelToSets[requestedModel] {
+ set := s.expandedSets[idx]
+ allInSet := true
+ for _, m := range runningModels {
+ if !slices.Contains(set.Models, m) {
+ allInSet = false
+ break
+ }
+ }
+ if allInSet {
+ return set.SetName, set.DSL
+ }
+ }
+ return "", ""
+}
+
+func (s *MatrixSolver) evictCost(model string) int {
+ if cost, ok := s.evictCosts[model]; ok {
+ return cost
+ }
+ return 1
+}
+
+// Matrix manages processes using solver-based swap logic.
+type Matrix struct {
+ sync.Mutex
+ solver *MatrixSolver
+ processes map[string]*Process // all processes keyed by real model name
+ config config.Config
+ proxyLogger *LogMonitor
+ upstreamLogger *LogMonitor
+}
+
+// NewMatrix creates a Matrix from config. It creates a Process for every
+// model defined in the config (any model can run alone even if not in a set).
+func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
+ processes := make(map[string]*Process)
+ for modelID, modelConfig := range cfg.Models {
+ processLogger := NewLogMonitorWriter(upstreamLogger)
+ process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
+ processes[modelID] = process
+ }
+
+ evictCosts := cfg.Matrix.ResolvedEvictCosts()
+
+ return &Matrix{
+ solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
+ processes: processes,
+ config: cfg,
+ proxyLogger: proxyLogger,
+ upstreamLogger: upstreamLogger,
+ }
+}
+
+// ProxyRequest handles the swap logic and proxies the request to the model.
+func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
+ process, ok := m.processes[modelID]
+ if !ok {
+ return fmt.Errorf("model %s not found in matrix", modelID)
+ }
+
+ m.Lock()
+ running := m.runningModels()
+ result, err := m.solver.Solve(modelID, running)
+ if err != nil {
+ m.Unlock()
+ return fmt.Errorf("matrix solver error: %w", err)
+ }
+
+ // Log solver decision
+ if len(result.Evict) > 0 {
+ m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
+ modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
+ } else if len(running) == 0 {
+ m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
+ } else {
+ m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
+ }
+
+ // Evict models that need to be stopped
+ if len(result.Evict) > 0 {
+ var wg sync.WaitGroup
+ for _, evictModel := range result.Evict {
+ if p, exists := m.processes[evictModel]; exists {
+ wg.Add(1)
+ go func(p *Process) {
+ defer wg.Done()
+ p.Stop()
+ }(p)
+ }
+ }
+ wg.Wait()
+ }
+ m.Unlock()
+
+ // Proxy the request (Process handles on-demand start)
+ process.ProxyRequest(w, r)
+ return nil
+}
+
+// StopProcesses stops all running processes.
+func (m *Matrix) StopProcesses(strategy StopStrategy) {
+ m.Lock()
+ defer m.Unlock()
+
+ var wg sync.WaitGroup
+ for _, process := range m.processes {
+ wg.Add(1)
+ go func(p *Process) {
+ defer wg.Done()
+ switch strategy {
+ case StopImmediately:
+ p.StopImmediately()
+ default:
+ p.Stop()
+ }
+ }(process)
+ }
+ wg.Wait()
+}
+
+// StopProcess stops a single process by model ID.
+func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
+ process, ok := m.processes[modelID]
+ if !ok {
+ return fmt.Errorf("process not found for %s", modelID)
+ }
+
+ switch strategy {
+ case StopImmediately:
+ process.StopImmediately()
+ default:
+ process.Stop()
+ }
+ return nil
+}
+
+// Shutdown shuts down all processes.
+func (m *Matrix) Shutdown() {
+ var wg sync.WaitGroup
+ for _, process := range m.processes {
+ wg.Add(1)
+ go func(p *Process) {
+ defer wg.Done()
+ p.Shutdown()
+ }(process)
+ }
+ wg.Wait()
+}
+
+// RunningModels returns model names currently in StateReady.
+func (m *Matrix) RunningModels() []string {
+ m.Lock()
+ defer m.Unlock()
+ return m.runningModels()
+}
+
+// runningModels returns running model names (caller must hold lock).
+func (m *Matrix) runningModels() []string {
+ var running []string
+ for id, process := range m.processes {
+ if process.CurrentState() == StateReady {
+ running = append(running, id)
+ }
+ }
+ sort.Strings(running)
+ return running
+}
+
+// GetProcess returns the Process for a model.
+func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
+ p, ok := m.processes[modelID]
+ return p, ok
+}
+
+// HasModel returns true if the model is managed by this matrix.
+func (m *Matrix) HasModel(modelID string) bool {
+ _, ok := m.processes[modelID]
+ return ok
+}
diff --git a/proxy/matrix_test.go b/proxy/matrix_test.go
new file mode 100644
index 00000000..81d5a1a8
--- /dev/null
+++ b/proxy/matrix_test.go
@@ -0,0 +1,227 @@
+package proxy
+
+import (
+ "testing"
+
+ "github.com/mostlygeek/llama-swap/proxy/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Helper to build expanded sets for solver tests
+func makeExpandedSets(sets ...struct {
+ name string
+ models []string
+}) []config.ExpandedSet {
+ var result []config.ExpandedSet
+ for _, s := range sets {
+ result = append(result, config.ExpandedSet{
+ SetName: s.name,
+ Models: s.models,
+ })
+ }
+ return result
+}
+
+func es(name string, models ...string) struct {
+ name string
+ models []string
+} {
+ return struct {
+ name string
+ models []string
+ }{name, models}
+}
+
+func TestMatrixSolver_AlreadyRunning(t *testing.T) {
+ solver := NewMatrixSolver(
+ makeExpandedSets(es("s1", "a", "b")),
+ nil,
+ )
+
+ result, err := solver.Solve("a", []string{"a"})
+ require.NoError(t, err)
+ assert.Empty(t, result.Evict)
+ assert.Equal(t, []string{"a"}, result.TargetSet)
+ assert.Equal(t, "s1", result.SetName)
+}
+
+func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
+ solver := NewMatrixSolver(
+ makeExpandedSets(es("s1", "a", "b")),
+ nil,
+ )
+
+ // Model "c" not in any set
+ result, err := solver.Solve("c", []string{"a", "b"})
+ require.NoError(t, err)
+ assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
+ assert.Equal(t, []string{"c"}, result.TargetSet)
+}
+
+func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
+ solver := NewMatrixSolver(
+ makeExpandedSets(es("s1", "a", "b")),
+ nil,
+ )
+
+ result, err := solver.Solve("c", []string{})
+ require.NoError(t, err)
+ assert.Empty(t, result.Evict)
+ assert.Equal(t, []string{"c"}, result.TargetSet)
+}
+
+func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
+ // Set: [a, b]. Request a when b and c are running.
+ solver := NewMatrixSolver(
+ makeExpandedSets(es("s1", "a", "b")),
+ nil,
+ )
+
+ result, err := solver.Solve("a", []string{"b", "c"})
+ require.NoError(t, err)
+ // c is not in the set, so it gets evicted. b is in the set, so it stays.
+ assert.Equal(t, []string{"c"}, result.Evict)
+ assert.Equal(t, []string{"a", "b"}, result.TargetSet)
+}
+
+func TestMatrixSolver_PicksLowestCost(t *testing.T) {
+ // Two sets containing model "a":
+ // s1: [a, v] — if v is running, cost=0; if L is running, cost=30
+ // s2: [a, L] — if L is running, cost=0; if v is running, cost=50
+ solver := NewMatrixSolver(
+ makeExpandedSets(
+ es("s1", "a", "v"),
+ es("s2", "a", "L"),
+ ),
+ map[string]int{"v": 50, "L": 30},
+ )
+
+ // v is running. Switching to a:
+ // s1 cost: v is in s1, so 0
+ // s2 cost: v is NOT in s2, so 50
+ // => pick s1
+ result, err := solver.Solve("a", []string{"v"})
+ require.NoError(t, err)
+ assert.Empty(t, result.Evict)
+ assert.Equal(t, []string{"a", "v"}, result.TargetSet)
+
+ // L is running. Switching to a:
+ // s1 cost: L is NOT in s1, so 30
+ // s2 cost: L is in s2, so 0
+ // => pick s2
+ result, err = solver.Solve("a", []string{"L"})
+ require.NoError(t, err)
+ assert.Empty(t, result.Evict)
+ assert.Equal(t, []string{"a", "L"}, result.TargetSet)
+}
+
+func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
+ // Two sets with identical cost. Definition order should win.
+ solver := NewMatrixSolver(
+ makeExpandedSets(
+ es("s1", "a", "x"),
+ es("s2", "a", "y"),
+ ),
+ nil,
+ )
+
+ // Nothing running, both sets cost 0. s1 is first.
+ result, err := solver.Solve("a", []string{})
+ require.NoError(t, err)
+ assert.Empty(t, result.Evict)
+ assert.Equal(t, []string{"a", "x"}, result.TargetSet)
+}
+
+func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
+ // Model "v" costs 50 to evict, "m" costs 1 (default).
+ // Sets: [g,v], [g,m]
+ // Running: v, m. Request g.
+ // s1=[g,v]: evict m (cost 1), keep v
+ // s2=[g,m]: evict v (cost 50), keep m
+ // => pick s1
+ solver := NewMatrixSolver(
+ makeExpandedSets(
+ es("s1", "g", "v"),
+ es("s2", "g", "m"),
+ ),
+ map[string]int{"v": 50},
+ )
+
+ result, err := solver.Solve("g", []string{"v", "m"})
+ require.NoError(t, err)
+ assert.Equal(t, []string{"m"}, result.Evict)
+ assert.Equal(t, []string{"g", "v"}, result.TargetSet)
+}
+
+func TestMatrixSolver_NothingRunning(t *testing.T) {
+ solver := NewMatrixSolver(
+ makeExpandedSets(
+ es("s1", "g", "v"),
+ es("s2", "q", "v"),
+ ),
+ nil,
+ )
+
+ result, err := solver.Solve("g", []string{})
+ require.NoError(t, err)
+ assert.Empty(t, result.Evict)
+ assert.Equal(t, []string{"g", "v"}, result.TargetSet)
+}
+
+func TestMatrixSolver_FullScenario(t *testing.T) {
+ // Simulates the example config:
+ // standard: [g,v], [q,v], [m,v]
+ // with_rerank: [g,v,e], [q,v,e]
+ // creative: [g,sd], [q,sd]
+ // full: [L]
+ solver := NewMatrixSolver(
+ makeExpandedSets(
+ es("standard", "g", "v"),
+ es("standard", "q", "v"),
+ es("standard", "m", "v"),
+ es("with_rerank", "e", "g", "v"),
+ es("with_rerank", "e", "q", "v"),
+ es("creative", "g", "sd"),
+ es("creative", "q", "sd"),
+ es("full", "L"),
+ ),
+ map[string]int{"v": 50, "L": 30, "whisper": 10},
+ )
+
+ // Running: g, v. Request q.
+ // standard[q,v]: evict g (cost 1), keep v. Total: 1.
+ // with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
+ // => tie, pick first by definition order = standard[q,v]
+ result, err := solver.Solve("q", []string{"g", "v"})
+ require.NoError(t, err)
+ assert.Equal(t, []string{"g"}, result.Evict)
+ assert.Equal(t, []string{"q", "v"}, result.TargetSet)
+
+ // Running: g, v. Request L.
+ // full[L]: evict g (cost 1) + v (cost 50). Total: 51.
+ // Only one set contains L, so pick it.
+ result, err = solver.Solve("L", []string{"g", "v"})
+ require.NoError(t, err)
+ assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
+ assert.Equal(t, []string{"L"}, result.TargetSet)
+
+ // Running: g, v. Request sd.
+ // creative[g,sd]: evict v (cost 50). Total: 50.
+ // creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
+ // => pick creative[g,sd]
+ result, err = solver.Solve("sd", []string{"g", "v"})
+ require.NoError(t, err)
+ assert.Equal(t, []string{"v"}, result.Evict)
+ assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
+
+ // Running: q, v, e. Request g.
+ // standard[g,v]: evict q (1) + e (1). Total: 2.
+ // with_rerank[g,v,e]: evict q (1). Total: 1.
+ // creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
+ // => pick with_rerank[g,v,e]
+ result, err = solver.Solve("g", []string{"e", "q", "v"})
+ require.NoError(t, err)
+ assert.Equal(t, []string{"q"}, result.Evict)
+ assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
+}
diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go
index a04d1cca..ee1d3484 100644
--- a/proxy/proxymanager.go
+++ b/proxy/proxymanager.go
@@ -77,6 +77,9 @@ type ProxyManager struct {
processGroups map[string]*ProcessGroup
+ // matrix-based swap (mutually exclusive with processGroups)
+ matrix *Matrix
+
inFlightCounter *InflightCounter
// shutdown signaling
@@ -203,10 +206,14 @@ func New(proxyConfig config.Config) *ProxyManager {
peerProxy: peerProxy,
}
- // create the process groups
- for groupID := range proxyConfig.Groups {
- processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
- pm.processGroups[groupID] = processGroup
+ // create either matrix or process groups (mutually exclusive)
+ if proxyConfig.Matrix != nil {
+ pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
+ } else {
+ for groupID := range proxyConfig.Groups {
+ processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
+ pm.processGroups[groupID] = processGroup
+ }
}
pm.setupGinEngine()
@@ -225,18 +232,29 @@ func New(proxyConfig config.Config) *ProxyManager {
}
proxyLogger.Infof("Preloading model: %s", modelID)
- processGroup, err := pm.swapProcessGroup(modelID)
- if err != nil {
+ var preloadErr error
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ if pm.matrix != nil {
+ preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
+ } else {
+ processGroup, err := pm.swapProcessGroup(modelID)
+ if err != nil {
+ preloadErr = err
+ } else {
+ preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
+ }
+ }
+
+ if preloadErr != nil {
event.Emit(ModelPreloadedEvent{
ModelName: modelID,
Success: false,
})
- proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
+ proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
continue
} else {
- req, _ := http.NewRequest("GET", "/", nil)
- processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{
ModelName: modelID,
Success: true,
@@ -453,6 +471,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock()
defer pm.Unlock()
+ if pm.matrix != nil {
+ pm.matrix.StopProcesses(strategy)
+ return
+ }
+
// stop Processes in parallel
var wg sync.WaitGroup
for _, processGroup := range pm.processGroups {
@@ -473,6 +496,12 @@ func (pm *ProxyManager) Shutdown() {
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
+ if pm.matrix != nil {
+ pm.matrix.Shutdown()
+ pm.shutdownCancel()
+ return
+ }
+
var wg sync.WaitGroup
// Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups {
@@ -639,10 +668,16 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return
}
- processGroup, err := pm.swapProcessGroup(modelID)
- if err != nil {
- pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
- return
+ var handler func(string, http.ResponseWriter, *http.Request) error
+ if pm.matrix != nil {
+ handler = pm.matrix.ProxyRequest
+ } else {
+ processGroup, err := pm.swapProcessGroup(modelID)
+ if err != nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
+ return
+ }
+ handler = processGroup.ProxyRequest
}
// rewrite the path
@@ -651,13 +686,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
- if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
+ if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return
}
} else {
- if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
+ if err := handler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return
@@ -683,10 +718,16 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
modelID, found := pm.config.RealModelName(requestedModel)
if found {
- processGroup, err := pm.swapProcessGroup(modelID)
- if err != nil {
- pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
- return
+ var localHandler func(string, http.ResponseWriter, *http.Request) error
+ if pm.matrix != nil {
+ localHandler = pm.matrix.ProxyRequest
+ } else {
+ processGroup, err := pm.swapProcessGroup(modelID)
+ if err != nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
+ return
+ }
+ localHandler = processGroup.ProxyRequest
}
// issue #69 allow custom model names to be sent to upstream
@@ -737,7 +778,7 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
- nextHandler = processGroup.ProxyRequest
+ nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
@@ -823,15 +864,19 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modelID, found := pm.config.RealModelName(requestedModel)
if found {
- processGroup, err := pm.swapProcessGroup(modelID)
- if err != nil {
- pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
- return
+ if pm.matrix != nil {
+ nextHandler = pm.matrix.ProxyRequest
+ } else {
+ processGroup, err := pm.swapProcessGroup(modelID)
+ if err != nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
+ return
+ }
+ nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
- nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
@@ -942,14 +987,18 @@ func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
var modelID string
if realModelID, found := pm.config.RealModelName(requestedModel); found {
- processGroup, err := pm.swapProcessGroup(realModelID)
- if err != nil {
- pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
- return
- }
modelID = realModelID
+ if pm.matrix != nil {
+ nextHandler = pm.matrix.ProxyRequest
+ } else {
+ processGroup, err := pm.swapProcessGroup(realModelID)
+ if err != nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
+ return
+ }
+ nextHandler = processGroup.ProxyRequest
+ }
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
- nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
modelID = requestedModel
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
@@ -1048,9 +1097,9 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
- for _, processGroup := range pm.processGroups {
- for _, process := range processGroup.processes {
- if process.CurrentState() == StateReady {
+ if pm.matrix != nil {
+ for _, modelID := range pm.matrix.RunningModels() {
+ if process, ok := pm.matrix.GetProcess(modelID); ok {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
@@ -1062,6 +1111,22 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
})
}
}
+ } else {
+ for _, processGroup := range pm.processGroups {
+ for _, process := range processGroup.processes {
+ if process.CurrentState() == StateReady {
+ runningProcesses = append(runningProcesses, gin.H{
+ "model": process.ID,
+ "state": process.state,
+ "cmd": process.config.Cmd,
+ "proxy": process.config.Proxy,
+ "ttl": process.config.UnloadAfter,
+ "name": process.config.Name,
+ "description": process.config.Description,
+ })
+ }
+ }
+ }
}
// Put the results under the `running` key.
diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go
index 00897c65..ba0506f7 100644
--- a/proxy/proxymanager_api.go
+++ b/proxy/proxymanager_api.go
@@ -55,27 +55,28 @@ func (pm *ProxyManager) getModelStatus() []Model {
// Iterate over sorted keys
for _, modelID := range modelIDs {
// Get process state
- processGroup := pm.findGroupByModelName(modelID)
state := "unknown"
- if processGroup != nil {
- process := processGroup.processes[modelID]
- if process != nil {
- var stateStr string
- switch process.CurrentState() {
- case StateReady:
- stateStr = "ready"
- case StateStarting:
- stateStr = "starting"
- case StateStopping:
- stateStr = "stopping"
- case StateShutdown:
- stateStr = "shutdown"
- case StateStopped:
- stateStr = "stopped"
- default:
- stateStr = "unknown"
- }
- state = stateStr
+ var process *Process
+ if pm.matrix != nil {
+ process, _ = pm.matrix.GetProcess(modelID)
+ } else {
+ processGroup := pm.findGroupByModelName(modelID)
+ if processGroup != nil {
+ process = processGroup.processes[modelID]
+ }
+ }
+ if process != nil {
+ switch process.CurrentState() {
+ case StateReady:
+ state = "ready"
+ case StateStarting:
+ state = "starting"
+ case StateStopping:
+ state = "stopping"
+ case StateShutdown:
+ state = "shutdown"
+ case StateStopped:
+ state = "stopped"
}
}
models = append(models, Model{
@@ -254,18 +255,23 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
return
}
- processGroup := pm.findGroupByModelName(realModelName)
- if processGroup == nil {
- pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
- return
+ var stopErr error
+ if pm.matrix != nil {
+ stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
+ } else {
+ processGroup := pm.findGroupByModelName(realModelName)
+ if processGroup == nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
+ return
+ }
+ stopErr = processGroup.StopProcess(realModelName, StopImmediately)
}
- if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
- pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
+ if stopErr != nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
return
- } else {
- c.String(http.StatusOK, "OK")
}
+ c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {