proxy: add swap matrix with solver-based model swapping (#646)
Add a new swap matrix to supersede groups for running concurrent models. The matrix uses a solver that picks the lowest cost evictions to make a requested model available. This simple approach along with a very basic DSL grammar can enable very complex swapping scenarios. - add DSL parser for set expressions with & (AND), | (OR), (), +ref - add MatrixConfig structs, validation, and topological sort for +ref - add MatrixSolver with cost-minimizing swap decisions - add Matrix runtime integrating solver with Process lifecycle - integrate matrix into ProxyManager with if-branches at all endpoints - update config.example.yaml and config-schema.json with matrix schema - config enforces groups XOR matrix (cannot use both) fixes #643
This commit is contained in:
+46
-47
@@ -2,69 +2,68 @@ name: Linux CI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ "main" ]
|
branches: ["main"]
|
||||||
# only run when backend source changes
|
# only run when backend source changes
|
||||||
# cmd/ is excluded because it contains utilities without tests
|
# cmd/ is excluded because it contains utilities without tests
|
||||||
paths:
|
paths:
|
||||||
- '**/*.go'
|
- "**/*.go"
|
||||||
- '!cmd/**'
|
- "!cmd/**"
|
||||||
- 'go.mod'
|
- "go.mod"
|
||||||
- 'go.sum'
|
- "go.sum"
|
||||||
- 'Makefile'
|
- "Makefile"
|
||||||
- '.github/workflows/go-ci.yml'
|
- ".github/workflows/go-ci.yml"
|
||||||
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ "main" ]
|
branches: ["main"]
|
||||||
paths:
|
paths:
|
||||||
- '**/*.go'
|
- "**/*.go"
|
||||||
- '!cmd/**'
|
- "!cmd/**"
|
||||||
- 'go.mod'
|
- "go.mod"
|
||||||
- 'go.sum'
|
- "go.sum"
|
||||||
- 'Makefile'
|
- "Makefile"
|
||||||
- '.github/workflows/go-ci.yml'
|
- ".github/workflows/go-ci.yml"
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
run-tests:
|
run-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version-file: go.mod
|
go-version-file: go.mod
|
||||||
|
|
||||||
# Only run in this linux based runner
|
# Only run in this linux based runner
|
||||||
- name: Check Formatting
|
- name: Check Formatting
|
||||||
run: |
|
run: |
|
||||||
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||||
gofmt -l . | grep -v 'event/.*_test.go'
|
gofmt -l . | grep -v 'event/.*_test.go'
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
# cache simple-responder to save the build time
|
# cache simple-responder to save the build time
|
||||||
- name: Restore Simple Responder
|
- name: Restore Simple Responder
|
||||||
id: restore-simple-responder
|
id: restore-simple-responder
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: ./build
|
path: ./build
|
||||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
# necessary for testing proxy/Process swapping
|
||||||
- name: Create simple-responder
|
- name: Create simple-responder
|
||||||
run: make simple-responder
|
run: make simple-responder
|
||||||
|
|
||||||
- name: Save Simple Responder
|
- name: Save Simple Responder
|
||||||
# nothing new to save ... skip this step
|
# nothing new to save ... skip this step
|
||||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
id: save-simple-responder
|
id: save-simple-responder
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
path: ./build
|
path: ./build
|
||||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
- name: Test all
|
- name: Test all
|
||||||
run: make test-all
|
run: make test-all
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
|
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
|
||||||
|
|
||||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||||
|
|
||||||
@@ -45,14 +45,14 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
- `/health` - just returns "OK"
|
- `/health` - just returns "OK"
|
||||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||||
- ✅ Customizable
|
- ✅ Customizable
|
||||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
- Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
|
||||||
- Automatic unloading of models after timeout by setting a `ttl`
|
- Automatic unloading of models after timeout by setting a `ttl`
|
||||||
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||||
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||||
|
|
||||||
### Web UI
|
### Web UI
|
||||||
|
|
||||||
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
|
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
|
||||||
|
|
||||||
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
|
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
|
||||||
|
|
||||||
@@ -64,16 +64,14 @@ Inspect request and responses:
|
|||||||
|
|
||||||
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
|
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
|
||||||
|
|
||||||
Manually load and unload models:
|
Manually load and unload models:
|
||||||
|
|
||||||
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
|
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
|
||||||
|
|
||||||
|
Real time log streaming:
|
||||||
Real time log streaming:
|
|
||||||
|
|
||||||
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
|
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
|
||||||
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
llama-swap can be installed in multiple ways
|
llama-swap can be installed in multiple ways
|
||||||
|
|||||||
+61
-1
@@ -325,6 +325,44 @@
|
|||||||
},
|
},
|
||||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||||
},
|
},
|
||||||
|
"matrix": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||||
|
"required": [
|
||||||
|
"vars",
|
||||||
|
"sets"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"vars": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||||
|
"minProperties": 1,
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"propertyNames": {
|
||||||
|
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"evict_costs": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"sets": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||||
|
"minProperties": 1,
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
},
|
||||||
"hooks": {
|
"hooks": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -456,5 +494,27 @@
|
|||||||
"default": {},
|
"default": {},
|
||||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"if": {
|
||||||
|
"required": ["groups"]
|
||||||
|
},
|
||||||
|
"then": {
|
||||||
|
"not": {
|
||||||
|
"required": ["matrix"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"if": {
|
||||||
|
"required": ["matrix"]
|
||||||
|
},
|
||||||
|
"then": {
|
||||||
|
"not": {
|
||||||
|
"required": ["groups"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
+71
-56
@@ -331,68 +331,83 @@ models:
|
|||||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
cmdStop: docker stop ${MODEL_ID}
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# groups: a dictionary of group settings
|
# =============================================================================
|
||||||
# - optional, default: empty dictionary
|
# matrix: run concurrent models with a solver-based swap DSL
|
||||||
# - provides advanced controls over model swapping behaviour
|
# =============================================================================
|
||||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
|
||||||
# - model IDs must be defined in the Models section
|
|
||||||
# - a model can only be a member of one group
|
|
||||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
|
||||||
# - see issue #109 for details
|
|
||||||
#
|
#
|
||||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
# Note:
|
||||||
groups:
|
# A config must use either a matrix or legacy groups, not both. A configuration error
|
||||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
||||||
# to run a time across the whole llama-swap instance
|
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||||
"group1":
|
#
|
||||||
# swap: controls the model swapping behaviour in within the group
|
# The matrix declares valid combinations of models that can run concurrently.
|
||||||
# - optional, default: true
|
# When a model is requested, the solver finds the cheapest way to make it
|
||||||
# - true : only one model is allowed to run at a time
|
# available by evicting as few (and least costly) running models as possible.
|
||||||
# - false: all models can run together, no swapping
|
#
|
||||||
swap: true
|
# Solver behavior:
|
||||||
|
# 1. Request arrives for model X
|
||||||
|
# 2. If X is already running, forward immediately. Done.
|
||||||
|
# 3. Find all sets containing X
|
||||||
|
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||||
|
# every running model NOT in that set
|
||||||
|
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||||
|
# 6. Evict what needs to stop. Start X. Forward request.
|
||||||
|
#
|
||||||
|
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||||
|
# Only the requested model is started — others are not preloaded.
|
||||||
|
#
|
||||||
|
# A model not appearing in any set can only run alone.
|
||||||
|
#
|
||||||
|
matrix:
|
||||||
|
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||||
|
# - required for sets and evict_costs settings
|
||||||
|
# - each entry maps a short name to a real model ID; do not use a model alias
|
||||||
|
# - used to keep set DSL logic short and easier to read
|
||||||
|
# - sets and evict_costs only use identifiers defined in vars
|
||||||
|
vars:
|
||||||
|
g: gemma-model
|
||||||
|
q: qwen-model
|
||||||
|
m: mistral-model
|
||||||
|
v: voxtral-model
|
||||||
|
e: reranker-model
|
||||||
|
L: llama-70B
|
||||||
|
sd: stable-diffusion
|
||||||
|
|
||||||
# exclusive: controls how the group affects other groups
|
# evict_costs: relative cost of losing a running model (default: 1)
|
||||||
# - optional, default: true
|
evict_costs:
|
||||||
# - true: causes all other groups to unload when this group runs a model
|
v: 50 # vllm backend, slow cold start
|
||||||
# - false: does not affect other groups
|
L: 30 # 70B weights, slow to load
|
||||||
exclusive: true
|
|
||||||
|
|
||||||
# members references the models defined above
|
# sets: named sets of concurrent model combinations
|
||||||
# required
|
# Values are DSL strings with operators:
|
||||||
members:
|
# & AND (models run together)
|
||||||
- "llama"
|
# | OR (alternatives)
|
||||||
- "qwen-unlisted"
|
# () grouping
|
||||||
|
# +ref inline another set's expression
|
||||||
|
#
|
||||||
|
# Expansion examples:
|
||||||
|
# "L" → [L]
|
||||||
|
# "a & b" → [a, b]
|
||||||
|
# "a | b" → [a], [b]
|
||||||
|
# "(a | b) & c" → [a, c], [b, c]
|
||||||
|
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||||
|
# "+llms & v" → expands llms inline, then applies & v
|
||||||
|
sets:
|
||||||
|
# LLM + TTS: switching between g/q/m won't evict v
|
||||||
|
# expands to: [g,v], [q,v], [m,v]
|
||||||
|
standard: "(g | q | m) & v"
|
||||||
|
|
||||||
# Example:
|
# LLM + TTS + reranker
|
||||||
# - in group2 all models can run at the same time
|
# expands to: [g,v,e], [q,v,e]
|
||||||
# - when a different group is loaded it causes all running models in this group to unload
|
with_rerank: "(g | q) & v & e"
|
||||||
"group2":
|
|
||||||
swap: false
|
|
||||||
|
|
||||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
# LLM + image generation, no TTS
|
||||||
# - the models in group2 will be loaded but will not unload any other groups
|
# expands to: [g,sd], [q,sd]
|
||||||
exclusive: false
|
creative: "(g | q) & sd"
|
||||||
members:
|
|
||||||
- "docker-llama"
|
|
||||||
- "modelA"
|
|
||||||
- "modelB"
|
|
||||||
|
|
||||||
# Example:
|
# 70B model uses all GPUs, can only run alone
|
||||||
# - a persistent group, prevents other groups from unloading it
|
# expands to: [L]
|
||||||
"forever":
|
full: "L"
|
||||||
# persistent: prevents over groups from unloading the models in this group
|
|
||||||
# - optional, default: false
|
|
||||||
# - does not affect individual model behaviour
|
|
||||||
persistent: true
|
|
||||||
|
|
||||||
# set swap/exclusive to false to prevent swapping inside the group
|
|
||||||
# and the unloading of other groups
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "forever-modelA"
|
|
||||||
- "forever-modelB"
|
|
||||||
- "forever-modelc"
|
|
||||||
|
|
||||||
# hooks: a dictionary of event triggers and actions
|
# hooks: a dictionary of event triggers and actions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
|
|||||||
+32
-13
@@ -129,6 +129,12 @@ type Config struct {
|
|||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
|
// swap matrix: solver-based alternative to groups
|
||||||
|
Matrix *MatrixConfig `yaml:"matrix"`
|
||||||
|
|
||||||
|
// populated during validation when matrix is configured
|
||||||
|
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||||
|
|
||||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
Macros MacroList `yaml:"macros"`
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
@@ -438,22 +444,35 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
config.Models[modelId] = modelConfig
|
config.Models[modelId] = modelConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
// groups XOR matrix
|
||||||
|
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||||
|
}
|
||||||
|
|
||||||
// Validate group members
|
if config.Matrix != nil {
|
||||||
memberUsage := make(map[string]string)
|
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||||
for groupID, groupConfig := range config.Groups {
|
if err != nil {
|
||||||
prevSet := make(map[string]bool)
|
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||||
for _, member := range groupConfig.Members {
|
}
|
||||||
if _, found := prevSet[member]; found {
|
config.ExpandedSets = expandedSets
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
} else {
|
||||||
}
|
config = AddDefaultGroupToConfig(config)
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
// Validate group members
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
memberUsage := make(map[string]string)
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
}
|
}
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,226 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// varKeyPattern matches a valid matrix var key: 1 to 8 ASCII letters or digits.
var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
|
||||||
|
|
||||||
|
// MatrixConfig represents the swap matrix configuration block.
|
||||||
|
type MatrixConfig struct {
|
||||||
|
Var map[string]string `yaml:"vars"`
|
||||||
|
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||||
|
Sets OrderedSets `yaml:"sets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEntry pairs the name of a set with the raw DSL expression defining it.
type SetEntry struct {
	Name string // set name (the key under `sets`)
	DSL  string // unparsed DSL expression (the value under `sets`)
}
|
||||||
|
|
||||||
|
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
|
||||||
|
type OrderedSets []SetEntry
|
||||||
|
|
||||||
|
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
if value.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("sets must be a mapping")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]SetEntry, 0, len(value.Content)/2)
|
||||||
|
for i := 0; i < len(value.Content); i += 2 {
|
||||||
|
keyNode := value.Content[i]
|
||||||
|
valueNode := value.Content[i+1]
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := keyNode.Decode(&name); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode set name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dsl string
|
||||||
|
if err := valueNode.Decode(&dsl); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, SetEntry{Name: name, DSL: dsl})
|
||||||
|
}
|
||||||
|
|
||||||
|
*os = entries
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandedSet describes a single combination of models that may run
// concurrently, obtained by expanding the DSL expression of one set.
type ExpandedSet struct {
	SetName string   // name of the set the combination came from
	DSL     string   // DSL expression of that set
	Models  []string // real model names, sorted
}
|
||||||
|
|
||||||
|
// ValidateMatrix validates the matrix config and returns all expanded sets.
|
||||||
|
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
|
||||||
|
if len(matrix.Sets) == 0 {
|
||||||
|
return nil, fmt.Errorf("matrix must define at least one set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matrix.Var) == 0 {
|
||||||
|
return nil, fmt.Errorf("matrix must define at least one var")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate var entries
|
||||||
|
if matrix.Var != nil {
|
||||||
|
for id, modelName := range matrix.Var {
|
||||||
|
if !varKeyPattern.MatchString(id) {
|
||||||
|
return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
|
||||||
|
}
|
||||||
|
if _, exists := models[modelName]; !exists {
|
||||||
|
return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate evict_costs
|
||||||
|
if matrix.EvictCosts != nil {
|
||||||
|
for key, cost := range matrix.EvictCosts {
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
|
||||||
|
}
|
||||||
|
if _, ok := matrix.Var[key]; !ok {
|
||||||
|
return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build dependency graph for +ref topological sort
|
||||||
|
setNames := make(map[string]bool)
|
||||||
|
for _, entry := range matrix.Sets {
|
||||||
|
setNames[entry.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := make(map[string][]string) // setName -> set names it depends on
|
||||||
|
for _, entry := range matrix.Sets {
|
||||||
|
refs, err := extractRefs(entry.DSL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
|
||||||
|
}
|
||||||
|
for _, ref := range refs {
|
||||||
|
if !setNames[ref] {
|
||||||
|
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deps[entry.Name] = refs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topological sort with cycle detection
|
||||||
|
order, err := topologicalSort(matrix.Sets, deps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand sets in topological order
|
||||||
|
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
|
||||||
|
var allExpanded []ExpandedSet
|
||||||
|
totalCombinations := 0
|
||||||
|
|
||||||
|
// Build ordered map for efficient lookup
|
||||||
|
setDSL := make(map[string]string)
|
||||||
|
for _, entry := range matrix.Sets {
|
||||||
|
setDSL[entry.Name] = entry.DSL
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range order {
|
||||||
|
dsl := setDSL[name]
|
||||||
|
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("set %q: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedRefs[name] = combos
|
||||||
|
|
||||||
|
// Resolve var IDs to real model names
|
||||||
|
for _, combo := range combos {
|
||||||
|
resolved := make([]string, len(combo))
|
||||||
|
for i, ident := range combo {
|
||||||
|
realName, ok := matrix.Var[ident]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
|
||||||
|
}
|
||||||
|
resolved[i] = realName
|
||||||
|
}
|
||||||
|
sort.Strings(resolved)
|
||||||
|
allExpanded = append(allExpanded, ExpandedSet{
|
||||||
|
SetName: name,
|
||||||
|
DSL: dsl,
|
||||||
|
Models: resolved,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
totalCombinations += len(combos)
|
||||||
|
if totalCombinations > maxDSLExpansions {
|
||||||
|
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allExpanded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// topologicalSort returns set names in dependency order.
|
||||||
|
// Returns an error if a cycle is detected.
|
||||||
|
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
|
||||||
|
// States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||||
|
state := make(map[string]int)
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
var visit func(name string) error
|
||||||
|
visit = func(name string) error {
|
||||||
|
switch state[name] {
|
||||||
|
case 1:
|
||||||
|
return fmt.Errorf("circular reference detected involving set %q", name)
|
||||||
|
case 2:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state[name] = 1
|
||||||
|
|
||||||
|
for _, dep := range deps[name] {
|
||||||
|
if err := visit(dep); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state[name] = 2
|
||||||
|
order = append(order, name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visit in definition order for deterministic output
|
||||||
|
for _, entry := range sets {
|
||||||
|
if state[entry.Name] == 0 {
|
||||||
|
if err := visit(entry.Name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return order, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolvedEvictCosts returns a map of real model name -> evict cost,
|
||||||
|
// resolving var IDs. Models not listed default to 1.
|
||||||
|
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
|
||||||
|
costs := make(map[string]int)
|
||||||
|
if m.EvictCosts == nil {
|
||||||
|
return costs
|
||||||
|
}
|
||||||
|
for key, cost := range m.EvictCosts {
|
||||||
|
// Resolve var ID if present
|
||||||
|
if realName, ok := m.Var[key]; ok {
|
||||||
|
costs[realName] = cost
|
||||||
|
} else {
|
||||||
|
costs[key] = cost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return costs
|
||||||
|
}
|
||||||
@@ -0,0 +1,376 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// maxDSLExpansions is the upper bound on the total number of expanded
// combinations across all sets that ValidateMatrix accepts.
const maxDSLExpansions = 1000
|
||||||
|
|
||||||
|
// tokenType identifies the kind of a lexical token in a set DSL expression.
type tokenType int

// Token kinds emitted by the DSL lexer.
const (
	tokIdent  tokenType = iota // identifier (model alias or name)
	tokAnd                     // &
	tokOr                      // |
	tokLParen                  // (
	tokRParen                  // )
	tokRef                     // +setName
	tokEOF                     // end of input
)

// token is a single lexeme: its kind and the text it carries.
type token struct {
	typ tokenType
	val string
}
|
||||||
|
|
||||||
|
// tokenize splits a DSL string into tokens.
|
||||||
|
func tokenize(input string) ([]token, error) {
|
||||||
|
var tokens []token
|
||||||
|
i := 0
|
||||||
|
runes := []rune(input)
|
||||||
|
|
||||||
|
for i < len(runes) {
|
||||||
|
ch := runes[i]
|
||||||
|
|
||||||
|
// skip whitespace
|
||||||
|
if unicode.IsSpace(ch) {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ch {
|
||||||
|
case '&':
|
||||||
|
tokens = append(tokens, token{tokAnd, "&"})
|
||||||
|
i++
|
||||||
|
case '|':
|
||||||
|
tokens = append(tokens, token{tokOr, "|"})
|
||||||
|
i++
|
||||||
|
case '(':
|
||||||
|
tokens = append(tokens, token{tokLParen, "("})
|
||||||
|
i++
|
||||||
|
case ')':
|
||||||
|
tokens = append(tokens, token{tokRParen, ")"})
|
||||||
|
i++
|
||||||
|
case '+':
|
||||||
|
// +ref: read the identifier that follows
|
||||||
|
i++
|
||||||
|
start := i
|
||||||
|
for i < len(runes) && isIdentChar(runes[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i == start {
|
||||||
|
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token{tokRef, string(runes[start:i])})
|
||||||
|
default:
|
||||||
|
if isIdentChar(ch) {
|
||||||
|
start := i
|
||||||
|
for i < len(runes) && isIdentChar(runes[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = append(tokens, token{tokEOF, ""})
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIdentChar reports whether ch may appear in a DSL identifier: any Unicode
// letter or digit, or one of '_', '-' and '.'.
func isIdentChar(ch rune) bool {
	switch ch {
	case '_', '-', '.':
		return true
	}
	return unicode.IsLetter(ch) || unicode.IsDigit(ch)
}
|
||||||
|
|
||||||
|
// dslNode is implemented by every AST node type of the set DSL. The
// unexported marker method restricts implementations to this package.
type dslNode interface {
	dslNode()
}

type (
	// andNode is an interior node holding its operand sub-expressions.
	andNode struct {
		children []dslNode
	}

	// orNode is an interior node holding its alternative sub-expressions.
	orNode struct {
		children []dslNode
	}

	// leafNode is a terminal node carrying an identifier.
	leafNode struct {
		name string
	}

	// refNode is a terminal node carrying the name of a referenced set.
	refNode struct {
		setName string
	}
)

// Marker methods: each node type satisfies dslNode.
func (andNode) dslNode()  {}
func (orNode) dslNode()   {}
func (leafNode) dslNode() {}
func (refNode) dslNode()  {}
|
||||||
|
|
||||||
|
// parser holds state for recursive-descent parsing.
|
||||||
|
type parser struct {
|
||||||
|
tokens []token
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) peek() token {
|
||||||
|
if p.pos < len(p.tokens) {
|
||||||
|
return p.tokens[p.pos]
|
||||||
|
}
|
||||||
|
return token{tokEOF, ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) next() token {
|
||||||
|
t := p.peek()
|
||||||
|
if t.typ != tokEOF {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) expect(typ tokenType) (token, error) {
|
||||||
|
t := p.next()
|
||||||
|
if t.typ != typ {
|
||||||
|
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grammar:
|
||||||
|
//
|
||||||
|
// expr = andExpr
|
||||||
|
// andExpr = orExpr ('&' orExpr)*
|
||||||
|
// orExpr = atom ('|' atom)*
|
||||||
|
// atom = ident | '+' ident | '(' expr ')'
|
||||||
|
//
|
||||||
|
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
|
||||||
|
func parse(tokens []token) (dslNode, error) {
|
||||||
|
p := &parser{tokens: tokens}
|
||||||
|
node, err := p.parseExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if p.peek().typ != tokEOF {
|
||||||
|
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
|
||||||
|
}
|
||||||
|
return node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseExpr() (dslNode, error) {
|
||||||
|
return p.parseOrExpr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseOrExpr() (dslNode, error) {
|
||||||
|
left, err := p.parseAndExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peek().typ == tokOr {
|
||||||
|
children := []dslNode{left}
|
||||||
|
for p.peek().typ == tokOr {
|
||||||
|
p.next() // consume |
|
||||||
|
right, err := p.parseAndExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
children = append(children, right)
|
||||||
|
}
|
||||||
|
return orNode{children: children}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseAndExpr() (dslNode, error) {
|
||||||
|
left, err := p.parseAtom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peek().typ == tokAnd {
|
||||||
|
children := []dslNode{left}
|
||||||
|
for p.peek().typ == tokAnd {
|
||||||
|
p.next() // consume &
|
||||||
|
right, err := p.parseAtom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
children = append(children, right)
|
||||||
|
}
|
||||||
|
return andNode{children: children}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseAtom() (dslNode, error) {
|
||||||
|
t := p.peek()
|
||||||
|
|
||||||
|
switch t.typ {
|
||||||
|
case tokIdent:
|
||||||
|
p.next()
|
||||||
|
return leafNode{name: t.val}, nil
|
||||||
|
|
||||||
|
case tokRef:
|
||||||
|
p.next()
|
||||||
|
return refNode{setName: t.val}, nil
|
||||||
|
|
||||||
|
case tokLParen:
|
||||||
|
p.next() // consume (
|
||||||
|
node, err := p.parseExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := p.expect(tokRParen); err != nil {
|
||||||
|
return nil, fmt.Errorf("missing closing parenthesis")
|
||||||
|
}
|
||||||
|
return node, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected token %q", t.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// expand walks the AST and produces all combinations.
|
||||||
|
// resolvedRefs contains previously expanded sets for +ref resolution.
|
||||||
|
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case leafNode:
|
||||||
|
return [][]string{{n.name}}, nil
|
||||||
|
|
||||||
|
case refNode:
|
||||||
|
expanded, ok := resolvedRefs[n.setName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
|
||||||
|
}
|
||||||
|
// Return a copy
|
||||||
|
result := make([][]string, len(expanded))
|
||||||
|
for i, combo := range expanded {
|
||||||
|
result[i] = make([]string, len(combo))
|
||||||
|
copy(result[i], combo)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
case orNode:
|
||||||
|
// Union of all children's expansions
|
||||||
|
var result [][]string
|
||||||
|
for _, child := range n.children {
|
||||||
|
childResult, err := expand(child, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, childResult...)
|
||||||
|
if len(result) > maxDSLExpansions {
|
||||||
|
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
case andNode:
|
||||||
|
// Cartesian product across children
|
||||||
|
result := [][]string{{}} // start with one empty combo
|
||||||
|
for _, child := range n.children {
|
||||||
|
childResult, err := expand(child, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, err = cartesianProduct(result, childResult, maxDSLExpansions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown node type %T", node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cartesianProduct pairs every combination in left with every combination in
// right, concatenating each pair (left elements first) into a fresh slice.
// Results are ordered by left index, then by right index.
//
// If len(left)*len(right) is greater than cap, no work is done and an error is
// returned instead.
func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
	// Multiply in 64 bits so the size check itself cannot overflow.
	total := int64(len(left)) * int64(len(right))
	if total > int64(cap) {
		return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
	}

	out := make([][]string, 0, len(left)*len(right))
	for i := range left {
		for j := range right {
			merged := make([]string, len(left[i])+len(right[j]))
			n := copy(merged, left[i])
			copy(merged[n:], right[j])
			out = append(out, merged)
		}
	}
	return out, nil
}
|
||||||
|
|
||||||
|
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
|
||||||
|
// resolvedRefs contains previously expanded sets for +ref inlining.
|
||||||
|
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||||
|
dsl = strings.TrimSpace(dsl)
|
||||||
|
if dsl == "" {
|
||||||
|
return nil, fmt.Errorf("empty DSL expression")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := tokenize(dsl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tokenize: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tree, err := parse(tokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := expand(tree, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplicate models within each combination and sort for consistency
|
||||||
|
for i, combo := range result {
|
||||||
|
result[i] = dedupAndSort(combo)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dedupAndSort returns the distinct entries of items in ascending string
// order. The input slice is not modified. An input with no entries yields a
// nil slice.
func dedupAndSort(items []string) []string {
	present := make(map[string]bool, len(items))
	var out []string
	for i := 0; i < len(items); i++ {
		name := items[i]
		if present[name] {
			continue
		}
		present[name] = true
		out = append(out, name)
	}
	sort.Strings(out)
	return out
}
|
||||||
|
|
||||||
|
// extractRefs scans a DSL string for +ref tokens without full parsing.
|
||||||
|
// Used for building the dependency graph for topological sorting.
|
||||||
|
func extractRefs(dsl string) ([]string, error) {
|
||||||
|
tokens, err := tokenize(dsl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var refs []string
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, t := range tokens {
|
||||||
|
if t.typ == tokRef && !seen[t.val] {
|
||||||
|
seen[t.val] = true
|
||||||
|
refs = append(refs, t.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return refs, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDSL_Tokenize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expect []token
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single identifier",
|
||||||
|
input: "abc",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "abc"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identifier with hyphens and dots",
|
||||||
|
input: "model-name.v2",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "model-name.v2"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "and expression",
|
||||||
|
input: "a & b",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "or expression",
|
||||||
|
input: "a | b",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokOr, "|"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parentheses",
|
||||||
|
input: "(a | b) & c",
|
||||||
|
expect: []token{
|
||||||
|
{tokLParen, "("},
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokOr, "|"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokRParen, ")"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "c"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ref token",
|
||||||
|
input: "+llms & v",
|
||||||
|
expect: []token{
|
||||||
|
{tokRef, "llms"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "v"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no whitespace",
|
||||||
|
input: "(a|b)&c",
|
||||||
|
expect: []token{
|
||||||
|
{tokLParen, "("},
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokOr, "|"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokRParen, ")"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "c"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty ref",
|
||||||
|
input: "+",
|
||||||
|
errMsg: "expected set name after '+'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid character",
|
||||||
|
input: "a @ b",
|
||||||
|
errMsg: "unexpected character",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tokens, err := tokenize(tt.input)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expect, tokens)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDSL_ParseAndExpand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dsl string
|
||||||
|
refs map[string][][]string
|
||||||
|
expect [][]string
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single model",
|
||||||
|
dsl: "L",
|
||||||
|
expect: [][]string{{"L"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two models with AND",
|
||||||
|
dsl: "a & b",
|
||||||
|
expect: [][]string{{"a", "b"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two models with OR",
|
||||||
|
dsl: "a | b",
|
||||||
|
expect: [][]string{{"a"}, {"b"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three models with OR",
|
||||||
|
dsl: "a | b | c",
|
||||||
|
expect: [][]string{{"a"}, {"b"}, {"c"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cartesian product (a|b) & (c|d)",
|
||||||
|
dsl: "(a | b) & (c | d)",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a", "c"},
|
||||||
|
{"a", "d"},
|
||||||
|
{"b", "c"},
|
||||||
|
{"b", "d"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three-way AND",
|
||||||
|
dsl: "a & b & c",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "(g | q | m) & v",
|
||||||
|
dsl: "(g | q | m) & v",
|
||||||
|
expect: [][]string{
|
||||||
|
{"g", "v"},
|
||||||
|
{"q", "v"},
|
||||||
|
{"m", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "(g | q) & v & e",
|
||||||
|
dsl: "(g | q) & v & e",
|
||||||
|
expect: [][]string{
|
||||||
|
{"e", "g", "v"},
|
||||||
|
{"e", "q", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "precedence: a | b & c means a | (b & c)",
|
||||||
|
dsl: "a | b & c",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a"},
|
||||||
|
{"b", "c"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "+ref inlining",
|
||||||
|
dsl: "+llms & v",
|
||||||
|
refs: map[string][][]string{
|
||||||
|
"llms": {{"g"}, {"q"}, {"m"}},
|
||||||
|
},
|
||||||
|
expect: [][]string{
|
||||||
|
{"g", "v"},
|
||||||
|
{"q", "v"},
|
||||||
|
{"m", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "+ref chained",
|
||||||
|
dsl: "+with_tts & e",
|
||||||
|
refs: map[string][][]string{
|
||||||
|
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
|
||||||
|
},
|
||||||
|
expect: [][]string{
|
||||||
|
{"e", "g", "v"},
|
||||||
|
{"e", "q", "v"},
|
||||||
|
{"e", "m", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dedup within combination",
|
||||||
|
dsl: "a & a",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty expression",
|
||||||
|
dsl: "",
|
||||||
|
errMsg: "empty DSL expression",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unmatched open paren",
|
||||||
|
dsl: "(a | b",
|
||||||
|
errMsg: "missing closing parenthesis",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unmatched close paren",
|
||||||
|
dsl: "a | b)",
|
||||||
|
errMsg: "unexpected token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown ref",
|
||||||
|
dsl: "+unknown",
|
||||||
|
errMsg: "unknown set reference +unknown",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty parens",
|
||||||
|
dsl: "()",
|
||||||
|
errMsg: "unexpected token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
refs := tt.refs
|
||||||
|
if refs == nil {
|
||||||
|
refs = map[string][][]string{}
|
||||||
|
}
|
||||||
|
result, err := ParseAndExpandDSL(tt.dsl, refs)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expect, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDSL_ExpansionCap(t *testing.T) {
|
||||||
|
// Build an expression that would exceed 1000 combinations:
|
||||||
|
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
|
||||||
|
var aItems, bItems []string
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
aItems = append(aItems, fmt.Sprintf("a%d", i))
|
||||||
|
bItems = append(bItems, fmt.Sprintf("b%d", i))
|
||||||
|
}
|
||||||
|
dsl := fmt.Sprintf("(%s) & (%s)",
|
||||||
|
join(aItems, " | "),
|
||||||
|
join(bItems, " | "),
|
||||||
|
)
|
||||||
|
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDSL_ExtractRefs(t *testing.T) {
|
||||||
|
refs, err := extractRefs("+llms & v & +other")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"llms", "other"}, refs)
|
||||||
|
|
||||||
|
refs, err = extractRefs("a & b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, refs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// join concatenates items into one string, placing sep between consecutive
// elements. An empty or nil items yields the empty string.
func join(items []string, sep string) string {
	var buf []byte
	for idx := range items {
		if idx != 0 {
			buf = append(buf, sep...)
		}
		buf = append(buf, items[idx]...)
	}
	return string(buf)
}
|
||||||
@@ -0,0 +1,305 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeModels(names ...string) map[string]ModelConfig {
|
||||||
|
m := make(map[string]ModelConfig)
|
||||||
|
for _, name := range names {
|
||||||
|
m[name] = ModelConfig{Cmd: "echo " + name}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_Basic(t *testing.T) {
|
||||||
|
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{
|
||||||
|
"g": "gemma",
|
||||||
|
"q": "qwen",
|
||||||
|
"m": "mistral",
|
||||||
|
"v": "voxtral",
|
||||||
|
"L": "llama70B",
|
||||||
|
},
|
||||||
|
EvictCosts: map[string]int{
|
||||||
|
"L": 30,
|
||||||
|
"v": 50,
|
||||||
|
},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "standard", DSL: "(g | q | m) & v"},
|
||||||
|
{Name: "full", DSL: "L"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expanded, err := ValidateMatrix(matrix, models)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||||
|
// full expands to [llama70B]
|
||||||
|
assert.Len(t, expanded, 4)
|
||||||
|
|
||||||
|
assert.Equal(t, "standard", expanded[0].SetName)
|
||||||
|
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
|
||||||
|
|
||||||
|
assert.Equal(t, "standard", expanded[1].SetName)
|
||||||
|
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
|
||||||
|
|
||||||
|
assert.Equal(t, "standard", expanded[2].SetName)
|
||||||
|
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
|
||||||
|
|
||||||
|
assert.Equal(t, "full", expanded[3].SetName)
|
||||||
|
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_WithRef(t *testing.T) {
|
||||||
|
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{
|
||||||
|
"g": "gemma",
|
||||||
|
"q": "qwen",
|
||||||
|
"m": "mistral",
|
||||||
|
"v": "voxtral",
|
||||||
|
"e": "reranker",
|
||||||
|
},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "llms", DSL: "g | q | m"},
|
||||||
|
{Name: "with_tts", DSL: "+llms & v"},
|
||||||
|
{Name: "mega", DSL: "+with_tts & e"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expanded, err := ValidateMatrix(matrix, models)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// llms: [gemma], [qwen], [mistral]
|
||||||
|
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||||
|
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
|
||||||
|
assert.Len(t, expanded, 9)
|
||||||
|
|
||||||
|
// Check mega entries
|
||||||
|
megaEntries := filterBySetName(expanded, "mega")
|
||||||
|
assert.Len(t, megaEntries, 3)
|
||||||
|
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_MapIDRequired(t *testing.T) {
|
||||||
|
// DSL cannot use real model names directly — must use var IDs
|
||||||
|
models := makeModels("gemma", "voxtral")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "combo", DSL: "g & voxtral"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown var ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
alias string
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
|
||||||
|
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
|
||||||
|
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{tt.alias: "gemma"},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"x": "nonexistent"},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "x"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
t.Run("zero cost", func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
EvictCosts: map[string]int{"g": 0},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "positive integer")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative cost", func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
EvictCosts: map[string]int{"g": -1},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "positive integer")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown var ID in evict_costs", func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
EvictCosts: map[string]int{"unknown": 5},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown var ID")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_CycleDetection(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "a", DSL: "+b"},
|
||||||
|
{Name: "b", DSL: "+a"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "circular reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "a", DSL: "+nonexistent"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "references undefined set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_NoSets(t *testing.T) {
|
||||||
|
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "at least one set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "s", DSL: "g & nonexistent"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown var ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
|
||||||
|
mc := &MatrixConfig{
|
||||||
|
Var: map[string]string{
|
||||||
|
"g": "gemma",
|
||||||
|
"L": "llama70B",
|
||||||
|
},
|
||||||
|
EvictCosts: map[string]int{
|
||||||
|
"L": 30,
|
||||||
|
"g": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
costs := mc.ResolvedEvictCosts()
|
||||||
|
assert.Equal(t, 30, costs["llama70B"])
|
||||||
|
assert.Equal(t, 5, costs["gemma"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_ConfigXOR(t *testing.T) {
|
||||||
|
// groups and matrix both defined
|
||||||
|
yaml := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: echo model1
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
members:
|
||||||
|
- model1
|
||||||
|
matrix:
|
||||||
|
sets:
|
||||||
|
s: "model1"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "cannot use both")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
|
||||||
|
yaml := `
|
||||||
|
models:
|
||||||
|
gemma:
|
||||||
|
cmd: echo gemma
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
qwen:
|
||||||
|
cmd: echo qwen
|
||||||
|
proxy: http://localhost:8081
|
||||||
|
matrix:
|
||||||
|
vars:
|
||||||
|
g: gemma
|
||||||
|
q: qwen
|
||||||
|
sets:
|
||||||
|
combo: "g | q"
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, cfg.Matrix)
|
||||||
|
assert.Len(t, cfg.ExpandedSets, 2)
|
||||||
|
// Groups should be empty when matrix is used
|
||||||
|
assert.Empty(t, cfg.Groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
|
||||||
|
var result []ExpandedSet
|
||||||
|
for _, s := range sets {
|
||||||
|
if s.SetName == name {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
+298
@@ -0,0 +1,298 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||||
|
// It is safe for concurrent reads after construction.
|
||||||
|
type MatrixSolver struct {
|
||||||
|
expandedSets []config.ExpandedSet // all valid model combinations
|
||||||
|
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||||
|
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
|
||||||
|
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
|
||||||
|
modelToSets := make(map[string][]int)
|
||||||
|
for i, es := range expandedSets {
|
||||||
|
for _, model := range es.Models {
|
||||||
|
modelToSets[model] = append(modelToSets[model], i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MatrixSolver{
|
||||||
|
expandedSets: expandedSets,
|
||||||
|
evictCosts: evictCosts,
|
||||||
|
modelToSets: modelToSets,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SolveResult is the outcome of a MatrixSolver.Solve call.
type SolveResult struct {
	// Evict lists the running models that must be stopped.
	Evict []string
	// TargetSet is the chosen set of models (for informational purposes).
	TargetSet []string
	// SetName is the name of the chosen set.
	SetName string
	// DSL is the original DSL expression for the chosen set.
	DSL string
	// TotalCost is the total eviction cost.
	TotalCost int
}
|
||||||
|
|
||||||
|
// Solve determines which models to evict when a model is requested.
|
||||||
|
//
|
||||||
|
// Algorithm:
|
||||||
|
// 1. If requestedModel is already running, no eviction needed.
|
||||||
|
// 2. Find all sets containing requestedModel.
|
||||||
|
// 3. If no sets found, the model runs alone; evict all running models.
|
||||||
|
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||||
|
// models NOT in that set.
|
||||||
|
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||||
|
// 6. Return models to evict and the chosen set.
|
||||||
|
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
|
||||||
|
// If already running, nothing to do (but fill in set info for logging)
|
||||||
|
if slices.Contains(runningModels, requestedModel) {
|
||||||
|
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||||
|
return SolveResult{
|
||||||
|
TargetSet: runningModels,
|
||||||
|
SetName: setName,
|
||||||
|
DSL: dsl,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
candidateIndices := s.modelToSets[requestedModel]
|
||||||
|
|
||||||
|
// Model not in any set: runs alone, evict everything
|
||||||
|
if len(candidateIndices) == 0 {
|
||||||
|
evict := make([]string, len(runningModels))
|
||||||
|
copy(evict, runningModels)
|
||||||
|
return SolveResult{
|
||||||
|
Evict: evict,
|
||||||
|
TargetSet: []string{requestedModel},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the cheapest candidate set
|
||||||
|
bestCost := -1
|
||||||
|
bestIdx := -1
|
||||||
|
|
||||||
|
for _, idx := range candidateIndices {
|
||||||
|
setModels := s.expandedSets[idx].Models
|
||||||
|
cost := 0
|
||||||
|
for _, running := range runningModels {
|
||||||
|
if !slices.Contains(setModels, running) {
|
||||||
|
cost += s.evictCost(running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||||
|
bestCost = cost
|
||||||
|
bestIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine which running models to evict
|
||||||
|
chosen := s.expandedSets[bestIdx]
|
||||||
|
var evict []string
|
||||||
|
for _, running := range runningModels {
|
||||||
|
if !slices.Contains(chosen.Models, running) {
|
||||||
|
evict = append(evict, running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SolveResult{
|
||||||
|
Evict: evict,
|
||||||
|
TargetSet: chosen.Models,
|
||||||
|
SetName: chosen.SetName,
|
||||||
|
DSL: chosen.DSL,
|
||||||
|
TotalCost: bestCost,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingSet finds the expanded set that contains all running models.
|
||||||
|
// Returns the set name and DSL, or empty strings if no match.
|
||||||
|
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||||
|
for _, idx := range s.modelToSets[requestedModel] {
|
||||||
|
set := s.expandedSets[idx]
|
||||||
|
allInSet := true
|
||||||
|
for _, m := range runningModels {
|
||||||
|
if !slices.Contains(set.Models, m) {
|
||||||
|
allInSet = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allInSet {
|
||||||
|
return set.SetName, set.DSL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MatrixSolver) evictCost(model string) int {
|
||||||
|
if cost, ok := s.evictCosts[model]; ok {
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matrix manages processes using solver-based swap logic.
|
||||||
|
type Matrix struct {
|
||||||
|
sync.Mutex
|
||||||
|
solver *MatrixSolver
|
||||||
|
processes map[string]*Process // all processes keyed by real model name
|
||||||
|
config config.Config
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
upstreamLogger *LogMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||||
|
// model defined in the config (any model can run alone even if not in a set).
|
||||||
|
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
|
||||||
|
processes := make(map[string]*Process)
|
||||||
|
for modelID, modelConfig := range cfg.Models {
|
||||||
|
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||||
|
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
|
||||||
|
processes[modelID] = process
|
||||||
|
}
|
||||||
|
|
||||||
|
evictCosts := cfg.Matrix.ResolvedEvictCosts()
|
||||||
|
|
||||||
|
return &Matrix{
|
||||||
|
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
|
||||||
|
processes: processes,
|
||||||
|
config: cfg,
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
upstreamLogger: upstreamLogger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyRequest handles the swap logic and proxies the request to the model.
|
||||||
|
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
process, ok := m.processes[modelID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("model %s not found in matrix", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
running := m.runningModels()
|
||||||
|
result, err := m.solver.Solve(modelID, running)
|
||||||
|
if err != nil {
|
||||||
|
m.Unlock()
|
||||||
|
return fmt.Errorf("matrix solver error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log solver decision
|
||||||
|
if len(result.Evict) > 0 {
|
||||||
|
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||||
|
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||||
|
} else if len(running) == 0 {
|
||||||
|
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
|
||||||
|
} else {
|
||||||
|
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict models that need to be stopped
|
||||||
|
if len(result.Evict) > 0 {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, evictModel := range result.Evict {
|
||||||
|
if p, exists := m.processes[evictModel]; exists {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
p.Stop()
|
||||||
|
}(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
m.Unlock()
|
||||||
|
|
||||||
|
// Proxy the request (Process handles on-demand start)
|
||||||
|
process.ProxyRequest(w, r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopProcesses stops all running processes.
|
||||||
|
func (m *Matrix) StopProcesses(strategy StopStrategy) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range m.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
p.StopImmediately()
|
||||||
|
default:
|
||||||
|
p.Stop()
|
||||||
|
}
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopProcess stops a single process by model ID.
|
||||||
|
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
|
||||||
|
process, ok := m.processes[modelID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("process not found for %s", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown shuts down all processes.
|
||||||
|
func (m *Matrix) Shutdown() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range m.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
p.Shutdown()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunningModels returns model names currently in StateReady.
|
||||||
|
func (m *Matrix) RunningModels() []string {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
return m.runningModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runningModels returns running model names (caller must hold lock).
|
||||||
|
func (m *Matrix) runningModels() []string {
|
||||||
|
var running []string
|
||||||
|
for id, process := range m.processes {
|
||||||
|
if process.CurrentState() == StateReady {
|
||||||
|
running = append(running, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(running)
|
||||||
|
return running
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProcess returns the Process for a model.
|
||||||
|
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
|
||||||
|
p, ok := m.processes[modelID]
|
||||||
|
return p, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasModel returns true if the model is managed by this matrix.
|
||||||
|
func (m *Matrix) HasModel(modelID string) bool {
|
||||||
|
_, ok := m.processes[modelID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
@@ -0,0 +1,227 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper to build expanded sets for solver tests
|
||||||
|
func makeExpandedSets(sets ...struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
}) []config.ExpandedSet {
|
||||||
|
var result []config.ExpandedSet
|
||||||
|
for _, s := range sets {
|
||||||
|
result = append(result, config.ExpandedSet{
|
||||||
|
SetName: s.name,
|
||||||
|
Models: s.models,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// es is shorthand for one makeExpandedSets argument: a set name followed by
// the models that belong to it.
func es(name string, models ...string) struct {
	name   string
	models []string
} {
	return struct {
		name   string
		models []string
	}{name, models}
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("a", []string{"a"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a"}, result.TargetSet)
|
||||||
|
assert.Equal(t, "s1", result.SetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model "c" not in any set
|
||||||
|
result, err := solver.Solve("c", []string{"a", "b"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("c", []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
|
||||||
|
// Set: [a, b]. Request a when b and c are running.
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("a", []string{"b", "c"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
// c is not in the set, so it gets evicted. b is in the set, so it stays.
|
||||||
|
assert.Equal(t, []string{"c"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
|
||||||
|
// Two sets containing model "a":
|
||||||
|
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
|
||||||
|
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "a", "v"),
|
||||||
|
es("s2", "a", "L"),
|
||||||
|
),
|
||||||
|
map[string]int{"v": 50, "L": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
// v is running. Switching to a:
|
||||||
|
// s1 cost: v is in s1, so 0
|
||||||
|
// s2 cost: v is NOT in s2, so 50
|
||||||
|
// => pick s1
|
||||||
|
result, err := solver.Solve("a", []string{"v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
|
||||||
|
|
||||||
|
// L is running. Switching to a:
|
||||||
|
// s1 cost: L is NOT in s1, so 30
|
||||||
|
// s2 cost: L is in s2, so 0
|
||||||
|
// => pick s2
|
||||||
|
result, err = solver.Solve("a", []string{"L"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
|
||||||
|
// Two sets with identical cost. Definition order should win.
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "a", "x"),
|
||||||
|
es("s2", "a", "y"),
|
||||||
|
),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Nothing running, both sets cost 0. s1 is first.
|
||||||
|
result, err := solver.Solve("a", []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
|
||||||
|
// Model "v" costs 50 to evict, "m" costs 1 (default).
|
||||||
|
// Sets: [g,v], [g,m]
|
||||||
|
// Running: v, m. Request g.
|
||||||
|
// s1=[g,v]: evict m (cost 1), keep v
|
||||||
|
// s2=[g,m]: evict v (cost 50), keep m
|
||||||
|
// => pick s1
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "g", "v"),
|
||||||
|
es("s2", "g", "m"),
|
||||||
|
),
|
||||||
|
map[string]int{"v": 50},
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("g", []string{"v", "m"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"m"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_NothingRunning(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "g", "v"),
|
||||||
|
es("s2", "q", "v"),
|
||||||
|
),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("g", []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||||
|
// Simulates the example config:
|
||||||
|
// standard: [g,v], [q,v], [m,v]
|
||||||
|
// with_rerank: [g,v,e], [q,v,e]
|
||||||
|
// creative: [g,sd], [q,sd]
|
||||||
|
// full: [L]
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("standard", "g", "v"),
|
||||||
|
es("standard", "q", "v"),
|
||||||
|
es("standard", "m", "v"),
|
||||||
|
es("with_rerank", "e", "g", "v"),
|
||||||
|
es("with_rerank", "e", "q", "v"),
|
||||||
|
es("creative", "g", "sd"),
|
||||||
|
es("creative", "q", "sd"),
|
||||||
|
es("full", "L"),
|
||||||
|
),
|
||||||
|
map[string]int{"v": 50, "L": 30, "whisper": 10},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Running: g, v. Request q.
|
||||||
|
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
|
||||||
|
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
|
||||||
|
// => tie, pick first by definition order = standard[q,v]
|
||||||
|
result, err := solver.Solve("q", []string{"g", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"g"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
|
||||||
|
|
||||||
|
// Running: g, v. Request L.
|
||||||
|
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||||
|
// Only one set contains L, so pick it.
|
||||||
|
result, err = solver.Solve("L", []string{"g", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"L"}, result.TargetSet)
|
||||||
|
|
||||||
|
// Running: g, v. Request sd.
|
||||||
|
// creative[g,sd]: evict v (cost 50). Total: 50.
|
||||||
|
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||||
|
// => pick creative[g,sd]
|
||||||
|
result, err = solver.Solve("sd", []string{"g", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"v"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
|
||||||
|
|
||||||
|
// Running: q, v, e. Request g.
|
||||||
|
// standard[g,v]: evict q (1) + e (1). Total: 2.
|
||||||
|
// with_rerank[g,v,e]: evict q (1). Total: 1.
|
||||||
|
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
|
||||||
|
// => pick with_rerank[g,v,e]
|
||||||
|
result, err = solver.Solve("g", []string{"e", "q", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"q"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
|
||||||
|
}
|
||||||
+99
-34
@@ -77,6 +77,9 @@ type ProxyManager struct {
|
|||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
|
// matrix-based swap (mutually exclusive with processGroups)
|
||||||
|
matrix *Matrix
|
||||||
|
|
||||||
inFlightCounter *InflightCounter
|
inFlightCounter *InflightCounter
|
||||||
|
|
||||||
// shutdown signaling
|
// shutdown signaling
|
||||||
@@ -203,10 +206,14 @@ func New(proxyConfig config.Config) *ProxyManager {
|
|||||||
peerProxy: peerProxy,
|
peerProxy: peerProxy,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the process groups
|
// create either matrix or process groups (mutually exclusive)
|
||||||
for groupID := range proxyConfig.Groups {
|
if proxyConfig.Matrix != nil {
|
||||||
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
|
||||||
pm.processGroups[groupID] = processGroup
|
} else {
|
||||||
|
for groupID := range proxyConfig.Groups {
|
||||||
|
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||||
|
pm.processGroups[groupID] = processGroup
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.setupGinEngine()
|
pm.setupGinEngine()
|
||||||
@@ -225,18 +232,29 @@ func New(proxyConfig config.Config) *ProxyManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||||
processGroup, err := pm.swapProcessGroup(modelID)
|
|
||||||
|
|
||||||
if err != nil {
|
var preloadErr error
|
||||||
|
req, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
if pm.matrix != nil {
|
||||||
|
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
preloadErr = err
|
||||||
|
} else {
|
||||||
|
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if preloadErr != nil {
|
||||||
event.Emit(ModelPreloadedEvent{
|
event.Emit(ModelPreloadedEvent{
|
||||||
ModelName: modelID,
|
ModelName: modelID,
|
||||||
Success: false,
|
Success: false,
|
||||||
})
|
})
|
||||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
req, _ := http.NewRequest("GET", "/", nil)
|
|
||||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
|
||||||
event.Emit(ModelPreloadedEvent{
|
event.Emit(ModelPreloadedEvent{
|
||||||
ModelName: modelID,
|
ModelName: modelID,
|
||||||
Success: true,
|
Success: true,
|
||||||
@@ -453,6 +471,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
|||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
if pm.matrix != nil {
|
||||||
|
pm.matrix.StopProcesses(strategy)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// stop Processes in parallel
|
// stop Processes in parallel
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, processGroup := range pm.processGroups {
|
for _, processGroup := range pm.processGroups {
|
||||||
@@ -473,6 +496,12 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
|
|
||||||
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||||
|
|
||||||
|
if pm.matrix != nil {
|
||||||
|
pm.matrix.Shutdown()
|
||||||
|
pm.shutdownCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
// Send shutdown signal to all process in groups
|
// Send shutdown signal to all process in groups
|
||||||
for _, processGroup := range pm.processGroups {
|
for _, processGroup := range pm.processGroups {
|
||||||
@@ -639,10 +668,16 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, err := pm.swapProcessGroup(modelID)
|
var handler func(string, http.ResponseWriter, *http.Request) error
|
||||||
if err != nil {
|
if pm.matrix != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
handler = pm.matrix.ProxyRequest
|
||||||
return
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler = processGroup.ProxyRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
@@ -651,13 +686,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
|
|
||||||
// attempt to record metrics if it is a POST request
|
// attempt to record metrics if it is a POST request
|
||||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
if err := handler(modelID, c.Writer, c.Request); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
return
|
return
|
||||||
@@ -683,10 +718,16 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
|||||||
|
|
||||||
modelID, found := pm.config.RealModelName(requestedModel)
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
if found {
|
if found {
|
||||||
processGroup, err := pm.swapProcessGroup(modelID)
|
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
||||||
if err != nil {
|
if pm.matrix != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
localHandler = pm.matrix.ProxyRequest
|
||||||
return
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
localHandler = processGroup.ProxyRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
@@ -737,7 +778,7 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
nextHandler = processGroup.ProxyRequest
|
nextHandler = localHandler
|
||||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
modelID = requestedModel
|
modelID = requestedModel
|
||||||
@@ -823,15 +864,19 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
|
|
||||||
modelID, found := pm.config.RealModelName(requestedModel)
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
if found {
|
if found {
|
||||||
processGroup, err := pm.swapProcessGroup(modelID)
|
if pm.matrix != nil {
|
||||||
if err != nil {
|
nextHandler = pm.matrix.ProxyRequest
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
} else {
|
||||||
return
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
useModelName = pm.config.Models[modelID].UseModelName
|
useModelName = pm.config.Models[modelID].UseModelName
|
||||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
nextHandler = processGroup.ProxyRequest
|
|
||||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
modelID = requestedModel
|
modelID = requestedModel
|
||||||
@@ -942,14 +987,18 @@ func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
|||||||
var modelID string
|
var modelID string
|
||||||
|
|
||||||
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||||
processGroup, err := pm.swapProcessGroup(realModelID)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
modelID = realModelID
|
modelID = realModelID
|
||||||
|
if pm.matrix != nil {
|
||||||
|
nextHandler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
nextHandler = processGroup.ProxyRequest
|
|
||||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
modelID = requestedModel
|
modelID = requestedModel
|
||||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
@@ -1048,9 +1097,9 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
context.Header("Content-Type", "application/json")
|
context.Header("Content-Type", "application/json")
|
||||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||||
|
|
||||||
for _, processGroup := range pm.processGroups {
|
if pm.matrix != nil {
|
||||||
for _, process := range processGroup.processes {
|
for _, modelID := range pm.matrix.RunningModels() {
|
||||||
if process.CurrentState() == StateReady {
|
if process, ok := pm.matrix.GetProcess(modelID); ok {
|
||||||
runningProcesses = append(runningProcesses, gin.H{
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
"model": process.ID,
|
"model": process.ID,
|
||||||
"state": process.state,
|
"state": process.state,
|
||||||
@@ -1062,6 +1111,22 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for _, processGroup := range pm.processGroups {
|
||||||
|
for _, process := range processGroup.processes {
|
||||||
|
if process.CurrentState() == StateReady {
|
||||||
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
|
"model": process.ID,
|
||||||
|
"state": process.state,
|
||||||
|
"cmd": process.config.Cmd,
|
||||||
|
"proxy": process.config.Proxy,
|
||||||
|
"ttl": process.config.UnloadAfter,
|
||||||
|
"name": process.config.Name,
|
||||||
|
"description": process.config.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put the results under the `running` key.
|
// Put the results under the `running` key.
|
||||||
|
|||||||
+34
-28
@@ -55,27 +55,28 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
|||||||
// Iterate over sorted keys
|
// Iterate over sorted keys
|
||||||
for _, modelID := range modelIDs {
|
for _, modelID := range modelIDs {
|
||||||
// Get process state
|
// Get process state
|
||||||
processGroup := pm.findGroupByModelName(modelID)
|
|
||||||
state := "unknown"
|
state := "unknown"
|
||||||
if processGroup != nil {
|
var process *Process
|
||||||
process := processGroup.processes[modelID]
|
if pm.matrix != nil {
|
||||||
if process != nil {
|
process, _ = pm.matrix.GetProcess(modelID)
|
||||||
var stateStr string
|
} else {
|
||||||
switch process.CurrentState() {
|
processGroup := pm.findGroupByModelName(modelID)
|
||||||
case StateReady:
|
if processGroup != nil {
|
||||||
stateStr = "ready"
|
process = processGroup.processes[modelID]
|
||||||
case StateStarting:
|
}
|
||||||
stateStr = "starting"
|
}
|
||||||
case StateStopping:
|
if process != nil {
|
||||||
stateStr = "stopping"
|
switch process.CurrentState() {
|
||||||
case StateShutdown:
|
case StateReady:
|
||||||
stateStr = "shutdown"
|
state = "ready"
|
||||||
case StateStopped:
|
case StateStarting:
|
||||||
stateStr = "stopped"
|
state = "starting"
|
||||||
default:
|
case StateStopping:
|
||||||
stateStr = "unknown"
|
state = "stopping"
|
||||||
}
|
case StateShutdown:
|
||||||
state = stateStr
|
state = "shutdown"
|
||||||
|
case StateStopped:
|
||||||
|
state = "stopped"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
models = append(models, Model{
|
models = append(models, Model{
|
||||||
@@ -254,18 +255,23 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup := pm.findGroupByModelName(realModelName)
|
var stopErr error
|
||||||
if processGroup == nil {
|
if pm.matrix != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
|
||||||
return
|
} else {
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
if stopErr != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
|
||||||
return
|
return
|
||||||
} else {
|
|
||||||
c.String(http.StatusOK, "OK")
|
|
||||||
}
|
}
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||||
|
|||||||
Reference in New Issue
Block a user