Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 316ad63f76 | |||
| e37077a963 | |||
| eff9b60434 | |||
| 9bcddad91b | |||
| a15e47922c | |||
| 0ab214d1c8 | |||
| d07b063ab6 | |||
| 826210dac9 | |||
| 6cf1317341 | |||
| 8e84b2ec4f | |||
| ed77385d08 | |||
| 92b90447e8 | |||
| 62aea0e83d | |||
| 8c660dcb90 | |||
| f6877b8175 | |||
| 9b3a33d7b9 |
@@ -15,6 +15,8 @@ reviews:
|
|||||||
auto_review:
|
auto_review:
|
||||||
enabled: false
|
enabled: false
|
||||||
drafts: false
|
drafts: false
|
||||||
|
unit_tests:
|
||||||
|
enabled: false
|
||||||
chat:
|
chat:
|
||||||
auto_reply: true
|
auto_reply: true
|
||||||
issue_enrichment:
|
issue_enrichment:
|
||||||
|
|||||||
@@ -44,13 +44,10 @@ jobs:
|
|||||||
|
|
||||||
echo "✓ config-schema.json is valid"
|
echo "✓ config-schema.json is valid"
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Go
|
||||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 #v6.2.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #v6.4.0
|
||||||
with:
|
with:
|
||||||
python-version: "3.x"
|
go-version-file: go.mod
|
||||||
|
|
||||||
- name: Install check-jsonschema
|
|
||||||
run: pip install check-jsonschema
|
|
||||||
|
|
||||||
- name: Validate config.example.yaml against schema
|
- name: Validate config.example.yaml against schema
|
||||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
run: go test ./internal/config/ -run TestConfig_ExampleMatchesSchema -v
|
||||||
|
|||||||
@@ -88,10 +88,11 @@ Real time log streaming:
|
|||||||
llama-swap can be installed in multiple ways
|
llama-swap can be installed in multiple ways
|
||||||
|
|
||||||
1. Docker
|
1. Docker
|
||||||
2. Homebrew (OSX and Linux)
|
2. Homebrew (macOS and Linux)
|
||||||
3. WinGet
|
3. MacPorts (macOS)
|
||||||
4. From release binaries
|
4. WinGet
|
||||||
5. From source
|
5. From release binaries
|
||||||
|
6. From source
|
||||||
|
|
||||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
|
|
||||||
@@ -155,6 +156,16 @@ brew install llama-swap
|
|||||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### MacPorts (macOS)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Maintained by the MacPorts community - [llama-swap port](https://ports.macports.org/port/llama-swap). It is not an official part of llama-swap.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sudo port install llama-swap
|
||||||
|
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||||
|
```
|
||||||
|
|
||||||
### WinGet Install (Windows)
|
### WinGet Install (Windows)
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
|
|||||||
+227
-73
@@ -82,6 +82,78 @@
|
|||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"description": "Timeout settings for proxy connections."
|
"description": "Timeout settings for proxy connections."
|
||||||
|
},
|
||||||
|
"groupsConfig": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"members"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"swap": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": true,
|
||||||
|
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||||
|
},
|
||||||
|
"exclusive": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": true,
|
||||||
|
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||||
|
},
|
||||||
|
"persistent": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||||
|
},
|
||||||
|
"members": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||||
|
},
|
||||||
|
"matrixConfig": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||||
|
"required": [
|
||||||
|
"vars",
|
||||||
|
"sets"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"vars": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||||
|
"minProperties": 1,
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"propertyNames": {
|
||||||
|
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"evict_costs": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"sets": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||||
|
"minProperties": 1,
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -306,81 +378,68 @@
|
|||||||
},
|
},
|
||||||
"timeouts": {
|
"timeouts": {
|
||||||
"$ref": "#/definitions/timeouts"
|
"$ref": "#/definitions/timeouts"
|
||||||
|
},
|
||||||
|
"capabilities": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"in": {
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"uniqueItems": true,
|
||||||
|
"default": [],
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"text",
|
||||||
|
"audio",
|
||||||
|
"image"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"description": "List of input modalities understood by the model."
|
||||||
|
},
|
||||||
|
"out": {
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"uniqueItems": true,
|
||||||
|
"default": [],
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"text",
|
||||||
|
"audio",
|
||||||
|
"image"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"description": "List of output modalities generated by the model."
|
||||||
|
},
|
||||||
|
"tools": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Whether the model supports function calling."
|
||||||
|
},
|
||||||
|
"reranker": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Whether the model supports the /v1/rerank endpoint."
|
||||||
|
},
|
||||||
|
"context": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 0,
|
||||||
|
"description": "Maximum token context length supported by the model."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"description": "Defines what the model accepts for input, output and other metadata. Used in v1/models to inform clients what the model can do. An empty capabilities block (all zero values) is treated as not configured."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"groups": {
|
"groups": {
|
||||||
"type": "object",
|
"$ref": "#/definitions/groupsConfig"
|
||||||
"additionalProperties": {
|
|
||||||
"type": "object",
|
|
||||||
"required": [
|
|
||||||
"members"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"swap": {
|
|
||||||
"type": "boolean",
|
|
||||||
"default": true,
|
|
||||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
|
||||||
},
|
|
||||||
"exclusive": {
|
|
||||||
"type": "boolean",
|
|
||||||
"default": true,
|
|
||||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
|
||||||
},
|
|
||||||
"persistent": {
|
|
||||||
"type": "boolean",
|
|
||||||
"default": false,
|
|
||||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
|
||||||
},
|
|
||||||
"members": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
|
||||||
},
|
},
|
||||||
"matrix": {
|
"matrix": {
|
||||||
"type": "object",
|
"$ref": "#/definitions/matrixConfig"
|
||||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
|
||||||
"required": [
|
|
||||||
"vars",
|
|
||||||
"sets"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"vars": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
|
||||||
"minProperties": 1,
|
|
||||||
"additionalProperties": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"propertyNames": {
|
|
||||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"evict_costs": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
|
||||||
"additionalProperties": {
|
|
||||||
"type": "integer",
|
|
||||||
"minimum": 1
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"sets": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
|
||||||
"minProperties": 1,
|
|
||||||
"additionalProperties": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false
|
|
||||||
},
|
},
|
||||||
"hooks": {
|
"hooks": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -512,28 +571,123 @@
|
|||||||
},
|
},
|
||||||
"default": {},
|
"default": {},
|
||||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||||
|
},
|
||||||
|
"upstream": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
|
||||||
|
"properties": {
|
||||||
|
"ignorePaths": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"default": [
|
||||||
|
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
|
||||||
|
],
|
||||||
|
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"default": {}
|
||||||
|
},
|
||||||
|
"routing": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||||
|
"properties": {
|
||||||
|
"scheduler": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Scheduler configuration. Decides the order in which queued requests are serviced.",
|
||||||
|
"properties": {
|
||||||
|
"use": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"fifo"
|
||||||
|
],
|
||||||
|
"default": "fifo",
|
||||||
|
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
||||||
|
},
|
||||||
|
"settings": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"fifo": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"priority": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Per-model priority. Keys are model IDs, values are integers (default 0). Higher values are serviced first.",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "integer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
},
|
||||||
|
"router": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Router configuration. Selects between the group and matrix swapping strategies.",
|
||||||
|
"properties": {
|
||||||
|
"use": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"group",
|
||||||
|
"matrix"
|
||||||
|
],
|
||||||
|
"default": "group",
|
||||||
|
"description": "Router to use. 'group' uses static groups, 'matrix' uses the solver-based swap matrix."
|
||||||
|
},
|
||||||
|
"settings": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"groups": {
|
||||||
|
"$ref": "#/definitions/groupsConfig"
|
||||||
|
},
|
||||||
|
"matrix": {
|
||||||
|
"$ref": "#/definitions/matrixConfig"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"allOf": [
|
"allOf": [
|
||||||
{
|
{
|
||||||
"if": {
|
"if": {
|
||||||
"required": ["groups"]
|
"required": [
|
||||||
|
"groups"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"then": {
|
"then": {
|
||||||
"not": {
|
"not": {
|
||||||
"required": ["matrix"]
|
"required": [
|
||||||
|
"matrix"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"if": {
|
"if": {
|
||||||
"required": ["matrix"]
|
"required": [
|
||||||
|
"matrix"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"then": {
|
"then": {
|
||||||
"not": {
|
"not": {
|
||||||
"required": ["groups"]
|
"required": [
|
||||||
|
"groups"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
+213
-87
@@ -134,6 +134,18 @@ apiKeys:
|
|||||||
- "${env.API_KEY_1}"
|
- "${env.API_KEY_1}"
|
||||||
- "${env.API_KEY_2}"
|
- "${env.API_KEY_2}"
|
||||||
|
|
||||||
|
# upstream: controls behaviour of the /upstream passthrough endpoint
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - recommended to only use in special use cases. Leaving it as the
|
||||||
|
# default will typically be the best experience
|
||||||
|
upstream:
|
||||||
|
# ignorePaths: list of RE2 compatible regular expressions
|
||||||
|
# - default: (see below)
|
||||||
|
# - any request to a path matching any of the regular expressions
|
||||||
|
# will be ignored and not trigger a swap
|
||||||
|
ignorePaths:
|
||||||
|
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
# - each key is the model's ID, used in API requests
|
# - each key is the model's ID, used in API requests
|
||||||
@@ -312,6 +324,37 @@ models:
|
|||||||
tlsHandshake: 10
|
tlsHandshake: 10
|
||||||
idleConn: 90
|
idleConn: 90
|
||||||
|
|
||||||
|
# capabilities: defines what the model accepts for input, output and other metadata
|
||||||
|
# - optional; omitted or all-zero means no capabilities
|
||||||
|
# - used in v1/models to inform clients what the model can do
|
||||||
|
capabilities:
|
||||||
|
# in: list of modalities understood by the model
|
||||||
|
# - default: []
|
||||||
|
# - valid: text, audio, image
|
||||||
|
in:
|
||||||
|
- text
|
||||||
|
- audio
|
||||||
|
- image
|
||||||
|
# out: list of modalities generated by the model
|
||||||
|
# - default: []
|
||||||
|
# - valid: text, audio, image
|
||||||
|
out:
|
||||||
|
- text
|
||||||
|
- audio
|
||||||
|
- image
|
||||||
|
# tools: the model supports function calling
|
||||||
|
# - default: false
|
||||||
|
tools: true
|
||||||
|
|
||||||
|
# reranker: the model supports the /v1/rerank endpoint
|
||||||
|
# - default: false
|
||||||
|
reranker: false
|
||||||
|
|
||||||
|
# context: the maximum token context length supported
|
||||||
|
# - default: 0
|
||||||
|
# - must be an integer >= 0
|
||||||
|
context: 32000
|
||||||
|
|
||||||
# Unlisted model example:
|
# Unlisted model example:
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
# unlisted: boolean, true or false
|
# unlisted: boolean, true or false
|
||||||
@@ -343,93 +386,6 @@ models:
|
|||||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
cmdStop: docker stop ${MODEL_ID}
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# matrix: run concurrent models with a solver-based swap DSL
|
|
||||||
# =============================================================================
|
|
||||||
#
|
|
||||||
# Matrix or Groups?
|
|
||||||
#
|
|
||||||
# Groups are available and fully supported. The syntax may be easier to use
|
|
||||||
# for simple use cases.
|
|
||||||
#
|
|
||||||
# Documentation can be found here:
|
|
||||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
|
||||||
#
|
|
||||||
# A config can only use a matrix (recommended) or groups. A configuration error
|
|
||||||
# will occur if both are defined. Groups is legacy but is fully supported with
|
|
||||||
# no plans to deprecate it.
|
|
||||||
#
|
|
||||||
# ~~~~~
|
|
||||||
#
|
|
||||||
# The matrix declares valid combinations of models that can run concurrently.
|
|
||||||
# When a model is requested, the solver finds the cheapest way to make it
|
|
||||||
# available by evicting as few (and least costly) running models as possible.
|
|
||||||
#
|
|
||||||
# Solver behavior:
|
|
||||||
# 1. Request arrives for model X
|
|
||||||
# 2. If X is already running, forward immediately. Done.
|
|
||||||
# 3. Find all sets containing X
|
|
||||||
# 4. For each candidate set, compute cost: sum of evict_costs for
|
|
||||||
# every running model NOT in that set
|
|
||||||
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
|
||||||
# 6. Evict what needs to stop. Start X. Forward request.
|
|
||||||
#
|
|
||||||
# Subset semantics: a set [a, b, c] means any subset is valid.
|
|
||||||
# Only the requested model is started — others are not preloaded.
|
|
||||||
#
|
|
||||||
# A model not appearing in any set can only run alone.
|
|
||||||
#
|
|
||||||
matrix:
|
|
||||||
# vars: short names for models (alphanumeric, 1-8 chars)
|
|
||||||
# - required for sets and evict_costs settings
|
|
||||||
# - each entry is a short name to a real model ID. Do not use an alias
|
|
||||||
# - used to keep set DSL logic short and easier to read
|
|
||||||
# - sets and evict_costs only use identifiers defined in vars
|
|
||||||
vars:
|
|
||||||
g: gemma-model
|
|
||||||
q: qwen-model
|
|
||||||
m: mistral-model
|
|
||||||
v: voxtral-model
|
|
||||||
e: reranker-model
|
|
||||||
L: llama-70B
|
|
||||||
sd: stable-diffusion
|
|
||||||
|
|
||||||
# evict_costs: relative cost of losing a running model (default: 1)
|
|
||||||
evict_costs:
|
|
||||||
v: 50 # vllm backend, slow cold start
|
|
||||||
L: 30 # 70B weights, slow to load
|
|
||||||
|
|
||||||
# sets: named sets of concurrent model combinations
|
|
||||||
# Values are DSL strings with operators:
|
|
||||||
# & AND (models run together)
|
|
||||||
# | OR (alternatives)
|
|
||||||
# () grouping
|
|
||||||
# +ref inline another set's expression
|
|
||||||
#
|
|
||||||
# Expansion examples:
|
|
||||||
# "L" → [L]
|
|
||||||
# "a & b" → [a, b]
|
|
||||||
# "a | b" → [a], [b]
|
|
||||||
# "(a | b) & c" → [a, c], [b, c]
|
|
||||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
|
||||||
# "+llms & v" → expands llms inline, then applies & v
|
|
||||||
sets:
|
|
||||||
# LLM + TTS: switching between g/q/m won't evict v
|
|
||||||
# expands to: [g,v], [q,v], [m,v]
|
|
||||||
standard: "(g | q | m) & v"
|
|
||||||
|
|
||||||
# LLM + TTS + reranker
|
|
||||||
# expands to: [g,v,e], [q,v,e]
|
|
||||||
with_rerank: "(g | q) & v & e"
|
|
||||||
|
|
||||||
# LLM + image generation, no TTS
|
|
||||||
# expands to: [g,sd], [q,sd]
|
|
||||||
creative: "(g | q) & sd"
|
|
||||||
|
|
||||||
# 70B model uses all GPUs, can only run alone
|
|
||||||
# expands to: [L]
|
|
||||||
full: "L"
|
|
||||||
|
|
||||||
# hooks: a dictionary of event triggers and actions
|
# hooks: a dictionary of event triggers and actions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - the only supported hook is on_startup
|
# - the only supported hook is on_startup
|
||||||
@@ -446,6 +402,176 @@ hooks:
|
|||||||
preload:
|
preload:
|
||||||
- "llama"
|
- "llama"
|
||||||
|
|
||||||
|
# routing:
|
||||||
|
# Controls how llama-swap decides which models can run at the same time and
|
||||||
|
# which get swapped out. Choose one of two swap engines:
|
||||||
|
#
|
||||||
|
# - group: the default engine. Simpler to configure. You define groups of
|
||||||
|
# models that run together, and loading one group typically unloads
|
||||||
|
# the others.
|
||||||
|
#
|
||||||
|
# - matrix: the newer engine. More involved to configure, but far more
|
||||||
|
# flexible. It uses a small expression language to describe which
|
||||||
|
# model combinations are allowed to run concurrently, enabling
|
||||||
|
# setups that groups cannot express.
|
||||||
|
#
|
||||||
|
# The routing section is optional.
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
# use: a string defining which engine to use
|
||||||
|
# - optional, default: "group"
|
||||||
|
# - valid values: group, matrix
|
||||||
|
use: group
|
||||||
|
|
||||||
|
# settings: a dictionary of settings for the specific engines
|
||||||
|
settings:
|
||||||
|
# groups: a dictionary of named groups
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - lets you keep some models loaded while others swap out
|
||||||
|
# - every member must be a model ID defined in the models section
|
||||||
|
# - a model can belong to only one group
|
||||||
|
# - behaviour is set per group with the `swap`, `exclusive` and
|
||||||
|
# `persistent` fields
|
||||||
|
# - see issue #109 for details
|
||||||
|
#
|
||||||
|
# NOTE: the model names below are illustrative and are not defined above.
|
||||||
|
groups:
|
||||||
|
# group1 reproduces llama-swap's default behaviour: only one model
|
||||||
|
# runs at a time across the entire instance.
|
||||||
|
"group1":
|
||||||
|
# swap: how members of this group swap among themselves
|
||||||
|
# - optional, default: true
|
||||||
|
# - true: only one member runs at a time
|
||||||
|
# - false: all members can run together, no swapping
|
||||||
|
swap: true
|
||||||
|
|
||||||
|
# exclusive: how this group affects other groups
|
||||||
|
# - optional, default: true
|
||||||
|
# - true: running a member unloads every other group
|
||||||
|
# - false: running a member leaves other groups untouched
|
||||||
|
exclusive: true
|
||||||
|
|
||||||
|
# members: the model IDs in this group
|
||||||
|
# - required
|
||||||
|
members:
|
||||||
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
|
|
||||||
|
# group2: members all run together, but loading any other group
|
||||||
|
# unloads them.
|
||||||
|
"group2":
|
||||||
|
# swap: false lets all members stay loaded at once
|
||||||
|
swap: false
|
||||||
|
|
||||||
|
# exclusive: false means requesting a member loads it without
|
||||||
|
# unloading any other group
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "docker-llama"
|
||||||
|
- "modelA"
|
||||||
|
- "modelB"
|
||||||
|
|
||||||
|
# forever: a persistent group that other groups can never unload.
|
||||||
|
"forever":
|
||||||
|
# persistent: other groups cannot unload this group's members
|
||||||
|
# - optional, default: false
|
||||||
|
# - has no effect on swapping within the group
|
||||||
|
persistent: true
|
||||||
|
|
||||||
|
# swap/exclusive: false keeps all members loaded and avoids
|
||||||
|
# unloading other groups
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "forever-modelA"
|
||||||
|
- "forever-modelB"
|
||||||
|
- "forever-modelc"
|
||||||
|
|
||||||
|
# The matrix lists the model combinations that are allowed to run
|
||||||
|
# concurrently. When a model is requested, the solver makes room for it
|
||||||
|
# by evicting as few running models as possible, preferring to keep the
|
||||||
|
# costliest ones loaded.
|
||||||
|
#
|
||||||
|
# Solver behaviour:
|
||||||
|
# 1. A request arrives for model X.
|
||||||
|
# 2. If X is already running, forward the request. Done.
|
||||||
|
# 3. Collect every set that contains X.
|
||||||
|
# 4. For each set, add up the evict_costs of the running models that
|
||||||
|
# are NOT in that set — that is the set's cost.
|
||||||
|
# 5. Choose the lowest-cost set. Break ties by definition order.
|
||||||
|
# 6. Evict the models outside that set, start X, forward the request.
|
||||||
|
#
|
||||||
|
# Subset semantics: a set [a, b, c] also permits any subset of itself.
|
||||||
|
# Only the requested model is started; the others are not preloaded.
|
||||||
|
#
|
||||||
|
# A model that appears in no set can only run on its own.
|
||||||
|
#
|
||||||
|
matrix:
|
||||||
|
# vars: short aliases for model IDs (alphanumeric, 1-8 chars)
|
||||||
|
# - required: sets and evict_costs reference these names, not model IDs
|
||||||
|
# - map each short name to a real model ID (not a model alias)
|
||||||
|
# - keeps the set expressions short and readable
|
||||||
|
vars:
|
||||||
|
g: gemma-model
|
||||||
|
q: qwen-model
|
||||||
|
m: mistral-model
|
||||||
|
v: voxtral-model
|
||||||
|
e: reranker-model
|
||||||
|
L: llama-70B
|
||||||
|
sd: stable-diffusion
|
||||||
|
|
||||||
|
# evict_costs: relative cost of losing a running model (default: 1)
|
||||||
|
evict_costs:
|
||||||
|
v: 50 # vllm backend, slow cold start
|
||||||
|
L: 30 # 70B weights, slow to load
|
||||||
|
|
||||||
|
# sets: named combinations of models that may run together.
|
||||||
|
# Each value is an expression built from these operators:
|
||||||
|
# & AND (models run together)
|
||||||
|
# | OR (alternatives)
|
||||||
|
# () grouping
|
||||||
|
# +ref inline the expression of another set
|
||||||
|
#
|
||||||
|
# Each expression expands into one or more concrete sets:
|
||||||
|
# "L" → [L]
|
||||||
|
# "a & b" → [a, b]
|
||||||
|
# "a | b" → [a], [b]
|
||||||
|
# "(a | b) & c" → [a, c], [b, c]
|
||||||
|
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||||
|
# "+llms & v" → inline the llms set, then AND with v
|
||||||
|
sets:
|
||||||
|
# An LLM plus TTS. Switching between g/q/m keeps v loaded.
|
||||||
|
# expands to: [g,v], [q,v], [m,v]
|
||||||
|
standard: "(g | q | m) & v"
|
||||||
|
|
||||||
|
# An LLM plus TTS plus reranker.
|
||||||
|
# expands to: [g,v,e], [q,v,e]
|
||||||
|
with_rerank: "(g | q) & v & e"
|
||||||
|
|
||||||
|
# An LLM plus image generation, no TTS.
|
||||||
|
# expands to: [g,sd], [q,sd]
|
||||||
|
creative: "(g | q) & sd"
|
||||||
|
|
||||||
|
# The 70B model uses every GPU, so it can only run alone.
|
||||||
|
# expands to: [L]
|
||||||
|
full: "L"
|
||||||
|
|
||||||
|
# scheduler: how queued requests are ordered.
|
||||||
|
# The default and only valid scheduler is "fifo"
|
||||||
|
scheduler:
|
||||||
|
use: fifo
|
||||||
|
settings:
|
||||||
|
fifo:
|
||||||
|
# priority: a dictionary of model ID -> priority
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - models default to priority 0
|
||||||
|
# - higher priority requests are serviced first in the queue
|
||||||
|
priority:
|
||||||
|
A: 10
|
||||||
|
B: 5
|
||||||
|
C: 5
|
||||||
|
D: 1
|
||||||
|
|
||||||
# peers: a dictionary of remote peers and models they provide
|
# peers: a dictionary of remote peers and models they provide
|
||||||
# - optional, default empty dictionary
|
# - optional, default empty dictionary
|
||||||
# - peers can be another llama-swap
|
# - peers can be another llama-swap
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ require (
|
|||||||
github.com/charmbracelet/lipgloss v1.1.0
|
github.com/charmbracelet/lipgloss v1.1.0
|
||||||
github.com/fxamacker/cbor/v2 v2.9.1
|
github.com/fxamacker/cbor/v2 v2.9.1
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
|
github.com/google/jsonschema-go v0.4.3
|
||||||
github.com/klauspost/compress v1.18.5
|
github.com/klauspost/compress v1.18.5
|
||||||
github.com/shirou/gopsutil/v4 v4.26.4
|
github.com/shirou/gopsutil/v4 v4.26.4
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
|
|||||||
@@ -61,6 +61,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||||
|
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||||
|
|||||||
+102
-6
@@ -129,13 +129,16 @@ type Config struct {
|
|||||||
GlobalTTL int `yaml:"globalTTL"`
|
GlobalTTL int `yaml:"globalTTL"`
|
||||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
|
||||||
|
|
||||||
// swap matrix: solver-based alternative to groups
|
// routing is the canonical source for swap/scheduling configuration.
|
||||||
Matrix *MatrixConfig `yaml:"matrix"`
|
// New code must read Routing, never the backwards-compat fields below.
|
||||||
|
Routing RoutingConfig `yaml:"routing"`
|
||||||
|
|
||||||
// populated during validation when matrix is configured
|
// Groups and Matrix are permanent backwards-compat input fields for the
|
||||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
// legacy top-level `groups:`/`matrix:` keys. They are normalized into
|
||||||
|
// Routing by LoadConfigFromReader. New code must not read them directly.
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
Matrix *MatrixConfig `yaml:"matrix"`
|
||||||
|
|
||||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
Macros MacroList `yaml:"macros"`
|
Macros MacroList `yaml:"macros"`
|
||||||
@@ -160,6 +163,38 @@ type Config struct {
|
|||||||
|
|
||||||
// support remote peers, see issue #433, #296
|
// support remote peers, see issue #433, #296
|
||||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||||
|
|
||||||
|
// upstream controls behaviour of the /upstream passthrough endpoint
|
||||||
|
Upstream UpstreamConfig `yaml:"upstream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||||
|
type RoutingConfig struct {
|
||||||
|
Scheduler SchedulerConfig `yaml:"scheduler"`
|
||||||
|
Router RouterConfig `yaml:"router"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SchedulerConfig struct {
|
||||||
|
Use string `yaml:"use"` // default "fifo"
|
||||||
|
Settings SchedulerSettings `yaml:"settings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SchedulerSettings struct {
|
||||||
|
Fifo FifoConfig `yaml:"fifo"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FifoConfig struct {
|
||||||
|
Priority map[string]int `yaml:"priority"` // model ID -> priority, default 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouterConfig struct {
|
||||||
|
Use string `yaml:"use"` // "group" (default) | "matrix"
|
||||||
|
Settings RouterSettings `yaml:"settings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouterSettings struct {
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"`
|
||||||
|
Matrix *MatrixConfig `yaml:"matrix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -238,6 +273,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply default for upstream.ignorePaths when not specified. The default
|
||||||
|
// matches common static-asset suffixes so they do not trigger a swap.
|
||||||
|
if len(config.Upstream.IgnorePaths) == 0 {
|
||||||
|
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||||
|
}
|
||||||
|
|
||||||
switch config.LogToStdout {
|
switch config.LogToStdout {
|
||||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
default:
|
default:
|
||||||
@@ -415,6 +456,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Validate SetParamsByID keys and values
|
// Validate SetParamsByID keys and values
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||||
@@ -455,6 +500,34 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
config.Models[modelId] = modelConfig
|
config.Models[modelId] = modelConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||||
|
// the new `routing.router` block are mutually exclusive: a config may use
|
||||||
|
// either style, never both.
|
||||||
|
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||||
|
rtr := config.Routing.Router
|
||||||
|
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||||
|
|
||||||
|
if hasTopLevel && hasRouting {
|
||||||
|
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasTopLevel {
|
||||||
|
// Both groups and matrix may be defined under routing.router.settings;
|
||||||
|
// routing.router.use selects which one is active, so there is no conflict.
|
||||||
|
rs := config.Routing.Router.Settings
|
||||||
|
switch config.Routing.Router.Use {
|
||||||
|
case "matrix":
|
||||||
|
if rs.Matrix == nil {
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||||
|
}
|
||||||
|
config.Matrix = rs.Matrix
|
||||||
|
case "group", "":
|
||||||
|
config.Groups = rs.Groups
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// groups XOR matrix
|
// groups XOR matrix
|
||||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||||
@@ -465,7 +538,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||||
}
|
}
|
||||||
config.ExpandedSets = expandedSets
|
config.Matrix.ExpandedSets = expandedSets
|
||||||
} else {
|
} else {
|
||||||
config = AddDefaultGroupToConfig(config)
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
|
||||||
@@ -487,6 +560,29 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||||
|
// and new-style configs converge here. The Matrix pointer is shared so
|
||||||
|
// ExpandedSets stays in one place.
|
||||||
|
if config.Matrix != nil {
|
||||||
|
config.Routing.Router.Use = "matrix"
|
||||||
|
} else {
|
||||||
|
config.Routing.Router.Use = "group"
|
||||||
|
}
|
||||||
|
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||||
|
config.Routing.Router.Settings.Groups = config.Groups
|
||||||
|
|
||||||
|
if config.Routing.Scheduler.Use == "" {
|
||||||
|
config.Routing.Scheduler.Use = "fifo"
|
||||||
|
}
|
||||||
|
if config.Routing.Scheduler.Use != "fifo" {
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
||||||
|
}
|
||||||
|
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||||
|
if _, found := config.RealModelName(modelID); !found {
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Clean up hooks preload
|
// Clean up hooks preload
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
var toPreload []string
|
var toPreload []string
|
||||||
|
|||||||
@@ -173,6 +173,25 @@ groups:
|
|||||||
IdleConn: 90,
|
IdleConn: 90,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expectedGroups := map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expected := Config{
|
expected := Config{
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
LogTimeFormat: "",
|
LogTimeFormat: "",
|
||||||
@@ -246,22 +265,19 @@ groups:
|
|||||||
"m2": "model2",
|
"m2": "model2",
|
||||||
"mthree": "model3",
|
"mthree": "model3",
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: expectedGroups,
|
||||||
DEFAULT_GROUP_ID: {
|
Upstream: UpstreamConfig{
|
||||||
Swap: true,
|
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||||
Exclusive: true,
|
},
|
||||||
Members: []string{"model1", "model3"},
|
Routing: RoutingConfig{
|
||||||
|
Router: RouterConfig{
|
||||||
|
Use: "group",
|
||||||
|
Settings: RouterSettings{
|
||||||
|
Groups: expectedGroups,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"group1": {
|
Scheduler: SchedulerConfig{
|
||||||
Swap: true,
|
Use: "fifo",
|
||||||
Exclusive: false,
|
|
||||||
Members: []string{"model2"},
|
|
||||||
},
|
|
||||||
"forever": {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: false,
|
|
||||||
Persistent: true,
|
|
||||||
Members: []string{"model4"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/jsonschema-go/jsonschema"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConfig_ExampleMatchesSchema validates that config.example.yaml conforms to
|
||||||
|
// config-schema.json. Both files live at the repository root.
|
||||||
|
func TestConfig_ExampleMatchesSchema(t *testing.T) {
|
||||||
|
const (
|
||||||
|
schemaPath = "../../config-schema.json"
|
||||||
|
examplePath = "../../config.example.yaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
schemaBytes, err := os.ReadFile(schemaPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading %s: %v", schemaPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var schema jsonschema.Schema
|
||||||
|
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
|
||||||
|
t.Fatalf("unmarshalling schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resolved, err := schema.Resolve(&jsonschema.ResolveOptions{
|
||||||
|
BaseURI: "https://github.com/mostlygeek/llama-swap/",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolving schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exampleBytes, err := os.ReadFile(examplePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading %s: %v", examplePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert YAML to a JSON-like value so numbers and keys match what the
|
||||||
|
// validator expects.
|
||||||
|
var yamlValue any
|
||||||
|
if err := yaml.Unmarshal(exampleBytes, &yamlValue); err != nil {
|
||||||
|
t.Fatalf("unmarshalling example yaml: %v", err)
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(yamlValue)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("converting example to json: %v", err)
|
||||||
|
}
|
||||||
|
var instance any
|
||||||
|
if err := json.Unmarshal(jsonBytes, &instance); err != nil {
|
||||||
|
t.Fatalf("unmarshalling example json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := resolved.Validate(instance); err != nil {
|
||||||
|
t.Fatalf("config.example.yaml does not match config-schema.json:\n%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1544,3 +1544,174 @@ peers:
|
|||||||
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||||
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// twoModels is a minimal models block reused by the routing tests below.
|
||||||
|
const twoModels = `
|
||||||
|
models:
|
||||||
|
gemma:
|
||||||
|
cmd: echo gemma
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
qwen:
|
||||||
|
cmd: echo qwen
|
||||||
|
proxy: http://localhost:8081
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestConfig_Routing_LegacyTopLevelGroups(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
groups:
|
||||||
|
g1:
|
||||||
|
members: [gemma, qwen]
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
|
// default group injected for orphaned models (none here) still leaves g1
|
||||||
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
|
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
matrix:
|
||||||
|
vars:
|
||||||
|
g: gemma
|
||||||
|
q: qwen
|
||||||
|
sets:
|
||||||
|
combo: "g | q"
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||||
|
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||||
|
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_RouterUseMatrix(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
use: matrix
|
||||||
|
settings:
|
||||||
|
matrix:
|
||||||
|
vars:
|
||||||
|
g: gemma
|
||||||
|
q: qwen
|
||||||
|
sets:
|
||||||
|
combo: "g | q"
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||||
|
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||||
|
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_RouterUseGroup(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
use: group
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
g1:
|
||||||
|
members: [gemma, qwen]
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
|
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
groups:
|
||||||
|
g1:
|
||||||
|
members: [gemma, qwen]
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
use: group
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "migrate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_RouterUseMatrixWithoutSettings(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
use: matrix
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "routing.router.settings.matrix is not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both groups and matrix may be defined under routing.router.settings;
|
||||||
|
// routing.router.use selects which one is active.
|
||||||
|
func TestConfig_Routing_RouterSettingsBothGroupsAndMatrix(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
use: group
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
g1:
|
||||||
|
members: [gemma, qwen]
|
||||||
|
matrix:
|
||||||
|
sets:
|
||||||
|
s: "gemma"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
// use: group means groups are active and matrix is ignored
|
||||||
|
assert.Equal(t, "group", config.Routing.Router.Use)
|
||||||
|
assert.Nil(t, config.Matrix)
|
||||||
|
assert.Contains(t, config.Groups, "g1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_UnknownRouter(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
use: bogus
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown router")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_FifoPriorityUnknownModel(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
scheduler:
|
||||||
|
settings:
|
||||||
|
fifo:
|
||||||
|
priority:
|
||||||
|
nope: 5
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Routing_FifoPriorityKnownModel(t *testing.T) {
|
||||||
|
yaml := twoModels + `
|
||||||
|
routing:
|
||||||
|
scheduler:
|
||||||
|
settings:
|
||||||
|
fifo:
|
||||||
|
priority:
|
||||||
|
gemma: 5
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 5, cfg.Routing.Scheduler.Settings.Fifo.Priority["gemma"])
|
||||||
|
}
|
||||||
|
|||||||
@@ -165,6 +165,25 @@ groups:
|
|||||||
IdleConn: 90,
|
IdleConn: 90,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expectedGroups := map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
expected := Config{
|
expected := Config{
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
LogTimeFormat: "",
|
LogTimeFormat: "",
|
||||||
@@ -235,22 +254,19 @@ groups:
|
|||||||
"m2": "model2",
|
"m2": "model2",
|
||||||
"mthree": "model3",
|
"mthree": "model3",
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: expectedGroups,
|
||||||
DEFAULT_GROUP_ID: {
|
Upstream: UpstreamConfig{
|
||||||
Swap: true,
|
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||||
Exclusive: true,
|
},
|
||||||
Members: []string{"model1", "model3"},
|
Routing: RoutingConfig{
|
||||||
|
Router: RouterConfig{
|
||||||
|
Use: "group",
|
||||||
|
Settings: RouterSettings{
|
||||||
|
Groups: expectedGroups,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"group1": {
|
Scheduler: SchedulerConfig{
|
||||||
Swap: true,
|
Use: "fifo",
|
||||||
Exclusive: false,
|
|
||||||
Members: []string{"model2"},
|
|
||||||
},
|
|
||||||
"forever": {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: false,
|
|
||||||
Persistent: true,
|
|
||||||
Members: []string{"model4"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ type MatrixConfig struct {
|
|||||||
Var map[string]string `yaml:"vars"`
|
Var map[string]string `yaml:"vars"`
|
||||||
EvictCosts map[string]int `yaml:"evict_costs"`
|
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||||
Sets OrderedSets `yaml:"sets"`
|
Sets OrderedSets `yaml:"sets"`
|
||||||
|
|
||||||
|
// populated by ValidateMatrix; not settable from yaml
|
||||||
|
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetEntry is a single named set with its DSL expression.
|
// SetEntry is a single named set with its DSL expression.
|
||||||
|
|||||||
@@ -289,7 +289,9 @@ matrix:
|
|||||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, cfg.Matrix)
|
assert.NotNil(t, cfg.Matrix)
|
||||||
assert.Len(t, cfg.ExpandedSets, 2)
|
assert.Len(t, cfg.Matrix.ExpandedSets, 2)
|
||||||
|
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||||
|
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||||
// Groups should be empty when matrix is used
|
// Groups should be empty when matrix is used
|
||||||
assert.Empty(t, cfg.Groups)
|
assert.Empty(t, cfg.Groups)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -9,6 +10,47 @@ const (
|
|||||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var validModalities = map[string]struct{}{
|
||||||
|
"text": {},
|
||||||
|
"audio": {},
|
||||||
|
"image": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelCapConfig defines what modalities and features a model supports.
|
||||||
|
// Used in /v1/models to inform clients. An empty block (all zero values) is
|
||||||
|
// treated as not configured.
|
||||||
|
type ModelCapConfig struct {
|
||||||
|
In []string `yaml:"in"`
|
||||||
|
Out []string `yaml:"out"`
|
||||||
|
Tools bool `yaml:"tools"`
|
||||||
|
Reranker bool `yaml:"reranker"`
|
||||||
|
Context int `yaml:"context"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty returns true when all fields are at their zero values.
|
||||||
|
func (c ModelCapConfig) Empty() bool {
|
||||||
|
return len(c.In) == 0 && len(c.Out) == 0 && !c.Tools && !c.Reranker && c.Context == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks that all modality values are recognized and context is
|
||||||
|
// non-negative. Returns an error if any value is invalid.
|
||||||
|
func (c ModelCapConfig) Validate() error {
|
||||||
|
for _, m := range c.In {
|
||||||
|
if _, ok := validModalities[m]; !ok {
|
||||||
|
return fmt.Errorf("capabilities.in: invalid modality %q, must be one of: text, audio, image", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range c.Out {
|
||||||
|
if _, ok := validModalities[m]; !ok {
|
||||||
|
return fmt.Errorf("capabilities.out: invalid modality %q, must be one of: text, audio, image", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Context < 0 {
|
||||||
|
return errors.New("capabilities.context: must be >= 0")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TimeoutsConfig holds timeout settings for proxy connections
|
// TimeoutsConfig holds timeout settings for proxy connections
|
||||||
// 0 = no timeout
|
// 0 = no timeout
|
||||||
type TimeoutsConfig struct {
|
type TimeoutsConfig struct {
|
||||||
@@ -55,6 +97,9 @@ type ModelConfig struct {
|
|||||||
// Timeout settings for proxy connections
|
// Timeout settings for proxy connections
|
||||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||||
|
|
||||||
|
// Capabilities defines what modalities and features the model supports.
|
||||||
|
Capabilities ModelCapConfig `yaml:"capabilities"`
|
||||||
|
|
||||||
// Copy of HealthCheckTimeout from global config
|
// Copy of HealthCheckTimeout from global config
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ models:
|
|||||||
stop:
|
stop:
|
||||||
- "<|end|>"
|
- "<|end|>"
|
||||||
- "<|stop|>"
|
- "<|stop|>"
|
||||||
`
|
`
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
@@ -170,3 +170,167 @@ models:
|
|||||||
assert.Equal(t, 0.7, setParams["temperature"])
|
assert.Equal(t, 0.7, setParams["temperature"])
|
||||||
assert.Equal(t, 0.9, setParams["top_p"])
|
assert.Equal(t, 0.9, setParams["top_p"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelCapabilities(t *testing.T) {
|
||||||
|
t.Run("all fields", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
capabilities:
|
||||||
|
in:
|
||||||
|
- text
|
||||||
|
- audio
|
||||||
|
- image
|
||||||
|
out:
|
||||||
|
- text
|
||||||
|
- audio
|
||||||
|
- image
|
||||||
|
tools: true
|
||||||
|
context: 32000
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mc := config.Models["model1"]
|
||||||
|
assert.False(t, mc.Capabilities.Empty())
|
||||||
|
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.In)
|
||||||
|
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.Out)
|
||||||
|
assert.True(t, mc.Capabilities.Tools)
|
||||||
|
assert.Equal(t, 32000, mc.Capabilities.Context)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partial fields", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
capabilities:
|
||||||
|
tools: true
|
||||||
|
context: 8192
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mc := config.Models["model1"]
|
||||||
|
assert.False(t, mc.Capabilities.Empty())
|
||||||
|
assert.Nil(t, mc.Capabilities.In)
|
||||||
|
assert.Nil(t, mc.Capabilities.Out)
|
||||||
|
assert.True(t, mc.Capabilities.Tools)
|
||||||
|
assert.Equal(t, 8192, mc.Capabilities.Context)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not set", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mc := config.Models["model1"]
|
||||||
|
assert.True(t, mc.Capabilities.Empty())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tools false is empty", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
capabilities:
|
||||||
|
tools: false
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mc := config.Models["model1"]
|
||||||
|
assert.True(t, mc.Capabilities.Empty())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reranker true is not empty", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
capabilities:
|
||||||
|
reranker: true
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mc := config.Models["model1"]
|
||||||
|
assert.False(t, mc.Capabilities.Empty())
|
||||||
|
assert.True(t, mc.Capabilities.Reranker)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reranker false is empty", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
capabilities:
|
||||||
|
reranker: false
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mc := config.Models["model1"]
|
||||||
|
assert.True(t, mc.Capabilities.Empty())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelCapabilities_Validate(t *testing.T) {
|
||||||
|
t.Run("valid_modalities", func(t *testing.T) {
|
||||||
|
caps := ModelCapConfig{
|
||||||
|
In: []string{"text", "image"},
|
||||||
|
Out: []string{"text", "audio"},
|
||||||
|
Tools: true,
|
||||||
|
Context: 100000,
|
||||||
|
}
|
||||||
|
assert.NoError(t, caps.Validate())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty_is_valid", func(t *testing.T) {
|
||||||
|
caps := ModelCapConfig{}
|
||||||
|
assert.NoError(t, caps.Validate())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_in_modality", func(t *testing.T) {
|
||||||
|
caps := ModelCapConfig{In: []string{"video"}}
|
||||||
|
err := caps.Validate()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "capabilities.in")
|
||||||
|
assert.Contains(t, err.Error(), "video")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_out_modality", func(t *testing.T) {
|
||||||
|
caps := ModelCapConfig{Out: []string{"video"}}
|
||||||
|
err := caps.Validate()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "capabilities.out")
|
||||||
|
assert.Contains(t, err.Error(), "video")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative_context", func(t *testing.T) {
|
||||||
|
caps := ModelCapConfig{Context: -1}
|
||||||
|
err := caps.Validate()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "capabilities.context")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects_invalid_at_load", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
capabilities:
|
||||||
|
in:
|
||||||
|
- text
|
||||||
|
- video
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "video")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultUpstreamIgnorePathsPattern is the default regular expression applied
|
||||||
|
// to upstream.ignorePaths when the section is empty or absent from the config.
|
||||||
|
// It matches common static-asset suffixes so requests for .js/.css/.png/etc.
|
||||||
|
// files do not trigger a model swap.
|
||||||
|
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`
|
||||||
|
|
||||||
|
// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used
|
||||||
|
// when upstream.ignorePaths is not specified in the config. The returned slice
|
||||||
|
// is fresh so callers may mutate it without affecting other configs.
|
||||||
|
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
|
||||||
|
return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
// It is populated from YAML through its UnmarshalYAML method rather than by
// direct field decoding (note the `yaml:"-"` tag below).
type UpstreamConfig struct {
	// IgnorePaths is a slice of compiled regular expressions. Any request to
	// /upstream/<model>/<path> whose remaining path matches any of these
	// expressions will be ignored and not trigger a swap. When the config
	// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
	IgnorePaths []*regexp.Regexp `yaml:"-"`
}
|
||||||
|
|
||||||
|
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
// plain strings, which are then compiled into *regexp.Regexp.
type rawUpstreamConfig struct {
	// IgnorePaths holds the uncompiled pattern strings read from the
	// `ignorePaths` key.
	IgnorePaths []string `yaml:"ignorePaths"`
}
|
||||||
|
|
||||||
|
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
|
||||||
|
// entry fails to compile, an error is returned.
|
||||||
|
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
var raw rawUpstreamConfig
|
||||||
|
if err := value.Decode(&raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
|
||||||
|
for _, p := range raw.IgnorePaths {
|
||||||
|
re, err := regexp.Compile(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
|
||||||
|
}
|
||||||
|
patterns = append(patterns, re)
|
||||||
|
}
|
||||||
|
u.IgnorePaths = patterns
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const upstreamConfigHeader = `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
|
||||||
|
// When upstream is not specified at all, the default pattern is applied.
|
||||||
|
content := upstreamConfigHeader
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||||
|
|
||||||
|
def := cfg.Upstream.IgnorePaths[0]
|
||||||
|
assert.IsType(t, ®exp.Regexp{}, def)
|
||||||
|
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
|
||||||
|
|
||||||
|
// The default matches common static-asset suffixes.
|
||||||
|
assert.True(t, def.MatchString("/foo.js"))
|
||||||
|
assert.True(t, def.MatchString("/bar/baz.json"))
|
||||||
|
assert.True(t, def.MatchString("/static/img.png"))
|
||||||
|
assert.True(t, def.MatchString("/notes.txt"))
|
||||||
|
assert.True(t, def.MatchString("/favicon.ico"))
|
||||||
|
// And does not match inference API paths.
|
||||||
|
assert.False(t, def.MatchString("/v1/chat/completions"))
|
||||||
|
assert.False(t, def.MatchString("/v1/models"))
|
||||||
|
assert.False(t, def.MatchString("/health"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
|
||||||
|
// When upstream is present but ignorePaths is omitted, the default is still
|
||||||
|
// applied.
|
||||||
|
content := `upstream: {}` + "\n" + upstreamConfigHeader
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||||
|
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
upstream:
|
||||||
|
ignorePaths:
|
||||||
|
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
|
||||||
|
- "^/static/.*"
|
||||||
|
` + upstreamConfigHeader
|
||||||
|
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 2)
|
||||||
|
|
||||||
|
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
|
||||||
|
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
|
||||||
|
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
|
||||||
|
|
||||||
|
// Confirm the type is *regexp.Regexp to satisfy the API contract.
|
||||||
|
for _, re := range cfg.Upstream.IgnorePaths {
|
||||||
|
assert.IsType(t, ®exp.Regexp{}, re)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
upstream:
|
||||||
|
ignorePaths:
|
||||||
|
- "[invalid("
|
||||||
|
` + upstreamConfigHeader
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "upstream.ignorePaths")
|
||||||
|
assert.Contains(t, err.Error(), "invalid regular expression")
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package perf
|
||||||
|
|
||||||
|
// LUID is the two-part adapter identifier carried in the AdapterLuid fields
// of the D3DKMT request structures in this package. It is a comparable value
// type, which lets callers use it as a map key.
// NOTE(review): presumably mirrors the Windows LUID structure — confirm.
type LUID struct {
	LowPart  uint32
	HighPart int32
}
|
||||||
|
|
||||||
|
// maxEnumAdapters is the capacity of the adapter array handed to
// D3DKMTEnumAdapters2 when enumerating adapters.
const maxEnumAdapters = 16

// D3DKMT_ENUMADAPTERS2 is the argument block passed to D3DKMTEnumAdapters2.
type D3DKMT_ENUMADAPTERS2 struct {
	// NumAdapters is set by the caller to the capacity of the array behind
	// pAdapters and is read back after the call as the number of entries.
	NumAdapters uint32
	// pAdapters is the address of the first D3DKMT_ADAPTERINFO element,
	// stored as a plain integer. Being a uintptr it does not keep the array
	// alive on its own.
	pAdapters uintptr
}
|
||||||
|
|
||||||
|
// D3DKMT_ADAPTERINFO is one element of the array filled in by
// D3DKMTEnumAdapters2. Within this package only AdapterLuid is consumed; it
// is the key used to open the adapter and to query its statistics.
// NOTE(review): the remaining fields presumably exist to keep the element
// size in step with the native structure — confirm against the Windows header.
type D3DKMT_ADAPTERINFO struct {
	hAdapter                     uint32
	AdapterLuid                  LUID
	NumOfSources                 uint32
	bPresentMoveRegionsPreferred int32
}
|
||||||
|
|
||||||
|
// D3DKMT_OPENADAPTERFROMLUID is the argument block for
// D3DKMTOpenAdapterFromLuid: the caller fills in AdapterLuid and reads the
// opened handle back from hAdapter after a successful call.
type D3DKMT_OPENADAPTERFROMLUID struct {
	AdapterLuid LUID
	hAdapter    uint32
}

// D3DKMT_CLOSEADAPTER is the argument block for D3DKMTCloseAdapter; hAdapter
// is the handle to close.
type D3DKMT_CLOSEADAPTER struct {
	hAdapter uint32
}
|
||||||
|
|
||||||
|
// KMTQUERYADAPTERINFOTYPE selects which kind of data a
// D3DKMT_QUERYADAPTERINFO request asks for (its Type field).
type KMTQUERYADAPTERINFOTYPE int32

// Query selectors. The numeric values are not contiguous and must not be
// renumbered.
// NOTE(review): values presumably copied from the Windows KMTQAITYPE_*
// enumeration — confirm against the header before adding new entries.
const (
	KMTQAITYPE_UMDRIVERPRIVATE          KMTQUERYADAPTERINFOTYPE = 0
	KMTQAITYPE_ADAPTERREGISTRYINFO      KMTQUERYADAPTERINFOTYPE = 8
	KMTQAITYPE_DRIVERVERSION            KMTQUERYADAPTERINFOTYPE = 13
	KMTQAITYPE_PHYSICALADAPTERDEVICEIDS KMTQUERYADAPTERINFOTYPE = 31
	KMTQAITYPE_NODEPERFDATA             KMTQUERYADAPTERINFOTYPE = 61
	KMTQAITYPE_ADAPTERPERFDATA          KMTQUERYADAPTERINFOTYPE = 62
	KMTQAITYPE_ADAPTERPERFDATA_CAPS     KMTQUERYADAPTERINFOTYPE = 63
)
|
||||||
|
|
||||||
|
// D3DKMT_QUERYADAPTERINFO is the argument block for D3DKMTQueryAdapterInfo.
type D3DKMT_QUERYADAPTERINFO struct {
	// hAdapter is an open adapter handle.
	hAdapter uint32
	// Type selects the kind of data requested.
	Type KMTQUERYADAPTERINFOTYPE
	// pPrivateDriverData is the address of the caller-owned output buffer,
	// stored as a plain integer.
	pPrivateDriverData uintptr
	// PrivateDriverDataSize is the size of that buffer in bytes.
	PrivateDriverDataSize uint32
}
|
||||||
|
|
||||||
|
// D3DKMT_ADAPTER_PERFDATA is the output buffer for a
// KMTQAITYPE_ADAPTERPERFDATA query. Field order and widths define the binary
// layout shared with the OS and must not be changed.
//
// NOTE(review): units below are as interpreted by this package's callers
// (values divided by 10 for power and temperature) — confirm against the
// Windows header.
type D3DKMT_ADAPTER_PERFDATA struct {
	PhysicalAdapterIndex uint32
	MemoryFrequency      uint64
	MaxMemoryFrequency   uint64
	MaxMemoryFrequencyOC uint64
	MemoryBandwidth      uint64
	PCIEBandwidth        uint64
	FanRPM               uint32 // fan speed in RPM
	Power                uint32 // reported in tenths of a unit
	Temperature          uint32 // reported in tenths of a degree Celsius
	PowerStateOverride   byte
}
|
||||||
|
|
||||||
|
// D3DKMT_QUERYSTATISTICS_TYPE selects which statistics a
// D3DKMTQueryStatistics request returns.
type D3DKMT_QUERYSTATISTICS_TYPE int32

// Statistics selectors, numbered consecutively from 0. The order of the
// entries is significant: do not insert or reorder.
const (
	D3DKMT_QUERYSTATISTICS_ADAPTER             D3DKMT_QUERYSTATISTICS_TYPE = iota // 0
	D3DKMT_QUERYSTATISTICS_PROCESS                                                // 1
	D3DKMT_QUERYSTATISTICS_PROCESS_ADAPTER                                        // 2
	D3DKMT_QUERYSTATISTICS_SEGMENT                                                // 3
	D3DKMT_QUERYSTATISTICS_PROCESS_SEGMENT                                        // 4
	D3DKMT_QUERYSTATISTICS_NODE                                                   // 5
	D3DKMT_QUERYSTATISTICS_PROCESS_NODE                                           // 6
	D3DKMT_QUERYSTATISTICS_VIDPNSOURCE                                            // 7
	D3DKMT_QUERYSTATISTICS_PROCESS_VIDPNSOURCE                                    // 8
)
|
||||||
|
|
||||||
|
// D3DKMT_ADAPTER_PERFDATACAPS is the output buffer for a
// KMTQAITYPE_ADAPTERPERFDATA_CAPS query and describes static limits of an
// adapter. Field order and widths define the binary layout shared with the
// OS and must not be changed.
type D3DKMT_ADAPTER_PERFDATACAPS struct {
	PhysicalAdapterIndex uint32
	MaxMemoryBandwidth   uint64
	MaxPCIEBandwidth     uint64
	MaxFanRPM            uint32 // used as the 100% reference for fan speed
	TemperatureMax       uint32
	TemperatureWarning   uint32
}
|
||||||
|
|
||||||
|
// D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT names the memory segment a segment
// statistics query refers to.
// NOTE(review): not referenced by the code visible alongside this
// declaration — confirm whether it is still needed.
type D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT struct {
	SegmentId uint32
}

// D3DKMT_QUERYSTATISTICS_QUERY_NODE names the node a node statistics query
// refers to.
// NOTE(review): not referenced by the code visible alongside this
// declaration — confirm whether it is still needed.
type D3DKMT_QUERYSTATISTICS_QUERY_NODE struct {
	NodeId uint32
}
|
||||||
@@ -0,0 +1,529 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package perf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
d3dkmDLL *windows.LazyDLL
|
||||||
|
procEnumAdapters2 *windows.LazyProc
|
||||||
|
procOpenAdapterFromLuid *windows.LazyProc
|
||||||
|
procCloseAdapter *windows.LazyProc
|
||||||
|
procQueryAdapterInfo *windows.LazyProc
|
||||||
|
procQueryStatistics *windows.LazyProc
|
||||||
|
d3dkmtInitOnce sync.Once
|
||||||
|
d3dkmtInitErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
// initD3DKMT lazily loads gdi32.dll and resolves D3DKMT function pointers.
|
||||||
|
// Safe for concurrent use via sync.Once.
|
||||||
|
func initD3DKMT() error {
|
||||||
|
d3dkmtInitOnce.Do(func() {
|
||||||
|
d3dkmDLL = windows.NewLazySystemDLL("gdi32.dll")
|
||||||
|
|
||||||
|
procEnumAdapters2 = d3dkmDLL.NewProc("D3DKMTEnumAdapters2")
|
||||||
|
procOpenAdapterFromLuid = d3dkmDLL.NewProc("D3DKMTOpenAdapterFromLuid")
|
||||||
|
procCloseAdapter = d3dkmDLL.NewProc("D3DKMTCloseAdapter")
|
||||||
|
procQueryAdapterInfo = d3dkmDLL.NewProc("D3DKMTQueryAdapterInfo")
|
||||||
|
procQueryStatistics = d3dkmDLL.NewProc("D3DKMTQueryStatistics")
|
||||||
|
|
||||||
|
for name, p := range map[string]*windows.LazyProc{
|
||||||
|
"D3DKMTEnumAdapters2": procEnumAdapters2,
|
||||||
|
"D3DKMTOpenAdapterFromLuid": procOpenAdapterFromLuid,
|
||||||
|
"D3DKMTCloseAdapter": procCloseAdapter,
|
||||||
|
"D3DKMTQueryAdapterInfo": procQueryAdapterInfo,
|
||||||
|
"D3DKMTQueryStatistics": procQueryStatistics,
|
||||||
|
} {
|
||||||
|
if err := p.Find(); err != nil {
|
||||||
|
d3dkmtInitErr = fmt.Errorf("D3DKMT %s not found: %w", name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return d3dkmtInitErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// ntstatusCall invokes a D3DKMT function and returns a non-nil error if the
|
||||||
|
// NTSTATUS result is not STATUS_SUCCESS (0).
|
||||||
|
func ntstatusCall(proc *windows.LazyProc, arg unsafe.Pointer) error {
|
||||||
|
ret, _, _ := proc.Call(uintptr(arg))
|
||||||
|
if ret != 0 {
|
||||||
|
return fmt.Errorf("NTSTATUS 0x%08x", uint32(ret))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmEnumerateAdapters enumerates all available graphics adapters via
|
||||||
|
// D3DKMTEnumAdapters2.
|
||||||
|
func d3dkmEnumerateAdapters() ([]D3DKMT_ADAPTERINFO, error) {
|
||||||
|
var adapters [maxEnumAdapters]D3DKMT_ADAPTERINFO
|
||||||
|
enum := D3DKMT_ENUMADAPTERS2{
|
||||||
|
NumAdapters: maxEnumAdapters,
|
||||||
|
pAdapters: uintptr(unsafe.Pointer(&adapters[0])),
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procEnumAdapters2, unsafe.Pointer(&enum)); err != nil {
|
||||||
|
return nil, fmt.Errorf("EnumAdapters2: %w", err)
|
||||||
|
}
|
||||||
|
if enum.NumAdapters == 0 {
|
||||||
|
return nil, fmt.Errorf("no adapters found")
|
||||||
|
}
|
||||||
|
result := make([]D3DKMT_ADAPTERINFO, enum.NumAdapters)
|
||||||
|
for i := uint32(0); i < enum.NumAdapters; i++ {
|
||||||
|
result[i] = adapters[i]
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmOpenAdapter opens a D3DKMT adapter handle for the given LUID.
|
||||||
|
func d3dkmOpenAdapter(luid LUID) (uint32, error) {
|
||||||
|
req := D3DKMT_OPENADAPTERFROMLUID{
|
||||||
|
AdapterLuid: luid,
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procOpenAdapterFromLuid, unsafe.Pointer(&req)); err != nil {
|
||||||
|
return 0, fmt.Errorf("OpenAdapterFromLuid: %w", err)
|
||||||
|
}
|
||||||
|
return req.hAdapter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmCloseAdapter closes a previously opened D3DKMT adapter handle.
|
||||||
|
func d3dkmCloseAdapter(hAdapter uint32) error {
|
||||||
|
req := D3DKMT_CLOSEADAPTER{hAdapter: hAdapter}
|
||||||
|
return ntstatusCall(procCloseAdapter, unsafe.Pointer(&req))
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmGetAdapterPerfData queries per-adapter performance data (temperature,
|
||||||
|
// fan RPM, power, bandwidth) via KMTQAITYPE_ADAPTERPERFDATA.
|
||||||
|
func d3dkmGetAdapterPerfData(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATA, error) {
|
||||||
|
var data D3DKMT_ADAPTER_PERFDATA
|
||||||
|
req := D3DKMT_QUERYADAPTERINFO{
|
||||||
|
hAdapter: hAdapter,
|
||||||
|
Type: KMTQAITYPE_ADAPTERPERFDATA,
|
||||||
|
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
|
||||||
|
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
|
||||||
|
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATA): %w", err)
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmGetAdapterPerfDataCaps queries static adapter performance capabilities
|
||||||
|
// (max fan RPM, temperature limits, max bandwidth) via KMTQAITYPE_ADAPTERPERFDATA_CAPS.
|
||||||
|
func d3dkmGetAdapterPerfDataCaps(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATACAPS, error) {
|
||||||
|
var data D3DKMT_ADAPTER_PERFDATACAPS
|
||||||
|
req := D3DKMT_QUERYADAPTERINFO{
|
||||||
|
hAdapter: hAdapter,
|
||||||
|
Type: KMTQAITYPE_ADAPTERPERFDATA_CAPS,
|
||||||
|
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
|
||||||
|
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
|
||||||
|
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATACAPS): %w", err)
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryStatsBuffer is the in/out argument block passed to
// D3DKMTQueryStatistics. The caller sets Type, AdapterLuid and (for segment
// and node queries) QueryId; results are read back as raw bytes from _result.
// Field order and sizes are load-bearing: they reproduce the native layout.
type queryStatsBuffer struct {
	Type        int32   // offset 0
	AdapterLuid LUID    // offset 4
	hProcess    uintptr // offset 16
	// _result mirrors the D3DKMT_QUERYSTATISTICS_RESULT union.
	// sizeof(D3DKMT_QUERYSTATISTICS) == 0x328 (808 bytes) on x64.
	//
	// The C struct layout (x64):
	//   offset   0: Type         (int32, 4 bytes)
	//   offset   4: AdapterLuid  (LUID, 8 bytes)
	//   offset  12: 4 bytes padding (for 8-byte alignment of hProcess)
	//   offset  16: hProcess     (HANDLE, 8 bytes)
	//   offset  24: QueryResult  (union, 780 bytes — largest member is AdapterInformation)
	//   offset 804: anonymous input union (QueryNode.NodeId / QuerySegment.SegmentId, 4 bytes)
	//
	// Previous bug: _result was [776]byte, placing QueryId at offset 800 instead of 804.
	// The kernel read NodeId/SegmentId from offset 804 (always zero from _pad),
	// causing all NODE and SEGMENT queries to use index 0 regardless of the value
	// passed in QueryId. This produced alternating behavior where only GPU util OR
	// memory util appeared to work, depending on which test variant happened to put
	// non-zero data near offset 804 in the result buffer.
	_result [780]byte // offset 24, size 780 — places QueryId at offset 804
	QueryId int32     // offset 804 — matches C anonymous union for NodeId/SegmentId
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var buf queryStatsBuffer
|
||||||
|
if unsafe.Sizeof(buf) != 808 {
|
||||||
|
panic(fmt.Sprintf("queryStatsBuffer size %d != expected 808 (sizeof D3DKMT_QUERYSTATISTICS on x64)", unsafe.Sizeof(buf)))
|
||||||
|
}
|
||||||
|
if unsafe.Offsetof(buf.QueryId) != 804 {
|
||||||
|
panic(fmt.Sprintf("queryStatsBuffer.QueryId offset %d != expected 804 (C anonymous union offset)", unsafe.Offsetof(buf.QueryId)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var perfData D3DKMT_ADAPTER_PERFDATA
|
||||||
|
if unsafe.Sizeof(perfData) != 64 {
|
||||||
|
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATA size %d != expected 64 on x64", unsafe.Sizeof(perfData)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var caps D3DKMT_ADAPTER_PERFDATACAPS
|
||||||
|
if unsafe.Sizeof(caps) != 40 {
|
||||||
|
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATACAPS size %d != expected 40 on x64", unsafe.Sizeof(caps)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Byte offsets into queryStatsBuffer._result. Which group applies depends on
// the query type that filled the buffer, so offsets from different groups
// legitimately overlap.
const (
	// Adapter query (D3DKMT_QUERYSTATISTICS_ADAPTER): two 32-bit counts.
	qsoffsetNbSegments = 0
	qsoffsetNodeCount  = 4
	// Segment query (D3DKMT_QUERYSTATISTICS_SEGMENT): three 64-bit values.
	qsoffsetCommitLimit    = 0
	qsoffsetBytesCommitted = 8
	qsoffsetBytesResident  = 16
	// Node query (D3DKMT_QUERYSTATISTICS_NODE): two 64-bit running times.
	qsoffsetRunningTime       = 0
	qsoffsetSystemRunningTime = 272
)
|
||||||
|
|
||||||
|
// d3dkmQueryAdapterStats returns the number of memory segments and compute
|
||||||
|
// nodes for the adapter identified by luid.
|
||||||
|
func d3dkmQueryAdapterStats(luid LUID) (nbSegments uint32, nodeCount uint32, err error) {
|
||||||
|
buf := queryStatsBuffer{
|
||||||
|
Type: int32(D3DKMT_QUERYSTATISTICS_ADAPTER),
|
||||||
|
AdapterLuid: luid,
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("QueryStatistics(ADAPTER): %w", err)
|
||||||
|
}
|
||||||
|
nbSegments = binary.LittleEndian.Uint32(buf._result[qsoffsetNbSegments : qsoffsetNbSegments+4])
|
||||||
|
nodeCount = binary.LittleEndian.Uint32(buf._result[qsoffsetNodeCount : qsoffsetNodeCount+4])
|
||||||
|
return nbSegments, nodeCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmQuerySegmentStats returns the commit limit (total) and resident
|
||||||
|
// (used) bytes for the given memory segment of an adapter.
|
||||||
|
func d3dkmQuerySegmentStats(luid LUID, segmentID uint32) (commitLimit uint64, bytesResident uint64, err error) {
|
||||||
|
buf := queryStatsBuffer{
|
||||||
|
Type: int32(D3DKMT_QUERYSTATISTICS_SEGMENT),
|
||||||
|
AdapterLuid: luid,
|
||||||
|
QueryId: int32(segmentID),
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("QueryStatistics(SEGMENT %d): %w", segmentID, err)
|
||||||
|
}
|
||||||
|
commitLimit = binary.LittleEndian.Uint64(buf._result[qsoffsetCommitLimit : qsoffsetCommitLimit+8])
|
||||||
|
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesResident : qsoffsetBytesResident+8])
|
||||||
|
if bytesResident == 0 {
|
||||||
|
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesCommitted : qsoffsetBytesCommitted+8])
|
||||||
|
}
|
||||||
|
return commitLimit, bytesResident, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// d3dkmQueryNodeStats returns the global and system running time counters
|
||||||
|
// (in 100ns units) for the given compute node of an adapter.
|
||||||
|
func d3dkmQueryNodeStats(luid LUID, nodeID uint32) (runningTime uint64, systemRunningTime uint64, err error) {
|
||||||
|
buf := queryStatsBuffer{
|
||||||
|
Type: int32(D3DKMT_QUERYSTATISTICS_NODE),
|
||||||
|
AdapterLuid: luid,
|
||||||
|
QueryId: int32(nodeID),
|
||||||
|
}
|
||||||
|
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("QueryStatistics(NODE %d): %w", nodeID, err)
|
||||||
|
}
|
||||||
|
runningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetRunningTime : qsoffsetRunningTime+8])
|
||||||
|
systemRunningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetSystemRunningTime : qsoffsetSystemRunningTime+8])
|
||||||
|
return runningTime, systemRunningTime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nodeRunningTimes is one sample of a node's two cumulative running-time
// counters.
type nodeRunningTimes struct {
	Global uint64
	System uint64
}

// d3dkmtNodeUtil turns two consecutive counter samples into a utilization
// percentage in the range [0, 100].
//
//   - If either counter decreased between the samples (wrap or reset), it
//     returns -1.
//   - If Global did not advance, it returns 0.
//   - If System advanced, the result is the Global delta divided by the
//     System delta.
//   - Otherwise, if elapsed100ns is positive, the Global delta is divided by
//     elapsed100ns instead.
//   - Otherwise it returns 0.
func d3dkmtNodeUtil(prevRT, curRT nodeRunningTimes, elapsed100ns int64) float64 {
	wentBackwards := curRT.Global < prevRT.Global || curRT.System < prevRT.System
	if wentBackwards {
		return -1
	}

	busy := curRT.Global - prevRT.Global
	if busy == 0 {
		return 0
	}

	if wall := curRT.System - prevRT.System; wall > 0 {
		fraction := float64(busy) / float64(wall)
		if fraction > 1.0 {
			fraction = 1.0
		}
		return fraction * 100.0
	}

	if elapsed100ns > 0 {
		pct := float64(busy) / float64(elapsed100ns) * 100.0
		if pct > 100.0 {
			pct = 100.0
		}
		return pct
	}

	return 0
}
|
||||||
|
|
||||||
|
// d3dkmtFanPct expresses fanRPM as a percentage of maxFanRPM, capped at 100.
// It returns 0 when either argument is 0 (no known maximum, or the fan is
// not spinning).
func d3dkmtFanPct(fanRPM, maxFanRPM uint32) float64 {
	if maxFanRPM == 0 || fanRPM == 0 {
		return 0
	}
	if share := float64(fanRPM) / float64(maxFanRPM) * 100.0; share <= 100.0 {
		return share
	}
	return 100.0
}
|
||||||
|
|
||||||
|
// d3dkmtPowerW divides the raw D3DKMT power reading by 10, treating the
// input as deci-watts and the result as watts. A zero reading yields 0.
// NOTE(review): the Windows header appears to describe this field as tenths
// of a percent rather than tenths of a watt — confirm the unit.
func d3dkmtPowerW(power uint32) float64 {
	if power == 0 {
		return 0
	}
	return float64(power) / 10.0
}
|
||||||
|
|
||||||
|
// d3dkmtTempC converts a temperature given in tenths of a degree Celsius
// into whole degrees Celsius, discarding the fractional part.
func d3dkmtTempC(tempDeciC uint32) int {
	wholeDegrees := tempDeciC / 10
	return int(wholeDegrees)
}
|
||||||
|
|
||||||
|
// d3dkmtAdapterState is the per-adapter bookkeeping kept by the monitoring
// goroutine started in tryD3DKMT.
type d3dkmtAdapterState struct {
	luid       LUID   // identifies the adapter in statistics queries
	hAdapter   uint32 // open adapter handle, closed when the goroutine exits
	nbSegments uint32 // number of memory segments to query
	nodeCount  uint32 // number of compute nodes to query
	maxFanRPM  uint32 // 0 when the capability query failed
	// prevNodeRT holds the previous running-time sample of each node, keyed
	// by node index, for computing deltas.
	prevNodeRT map[uint32]nodeRunningTimes
	// prevTime is when prevNodeRT was last refreshed.
	prevTime time.Time
}
|
||||||
|
|
||||||
|
// tryD3DKMT attempts to start GPU monitoring using D3DKMT and optional PDH
// counters. It returns a channel of GpuStat snapshots or an error if no
// usable adapters are found.
//
// Parameters:
//   - ctx: cancelling it stops the background goroutine, which then closes
//     the adapter handles and the returned channel.
//   - every: sampling interval used for the ticker.
//   - logger: receives debug/info messages about skipped adapters and the
//     chosen utilization source.
//
// The returned channel has a buffer of one; a snapshot is dropped rather
// than blocking when the previous one has not been consumed yet.
func tryD3DKMT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
	if err := initD3DKMT(); err != nil {
		return nil, err
	}

	adapterInfos, err := d3dkmEnumerateAdapters()
	if err != nil {
		return nil, err
	}

	// adapterMeta is the static information gathered per adapter during the
	// synchronous probe below.
	type adapterMeta struct {
		luid       LUID
		nbSegments uint32
		nodeCount  uint32
		maxFanRPM  uint32
	}

	var metaList []adapterMeta

	// Probe phase: open each adapter, read its static properties, and close
	// it again. Adapters that cannot be opened or queried are skipped.
	for i, ai := range adapterInfos {
		hAdapter, err := d3dkmOpenAdapter(ai.AdapterLuid)
		if err != nil {
			logger.Debugf("adapter %d: open failed: %s", i, err.Error())
			continue
		}

		nbSegments, nodeCount, err := d3dkmQueryAdapterStats(ai.AdapterLuid)
		if err != nil {
			logger.Debugf("adapter %d: query stats failed: %s", i, err.Error())
			d3dkmCloseAdapter(hAdapter)
			continue
		}

		// A failed capability query is tolerated: the adapter is still
		// monitored, just without a fan-speed reference (maxFanRPM stays 0).
		caps, err := d3dkmGetAdapterPerfDataCaps(hAdapter)
		if err != nil {
			logger.Debugf("adapter %d: perf caps failed: %s", i, err.Error())
		}

		d3dkmCloseAdapter(hAdapter)

		var maxFanRPM uint32
		if caps != nil {
			maxFanRPM = caps.MaxFanRPM
		}

		metaList = append(metaList, adapterMeta{
			luid:       ai.AdapterLuid,
			nbSegments: nbSegments,
			nodeCount:  nodeCount,
			maxFanRPM:  maxFanRPM,
		})
		logger.Debugf("adapter %d: segments=%d nodes=%d fan_max=%d luid=%d:%d", i, nbSegments, nodeCount, maxFanRPM, ai.AdapterLuid.HighPart, ai.AdapterLuid.LowPart)
	}

	if len(metaList) == 0 {
		return nil, fmt.Errorf("no usable D3DKMT adapters found")
	}

	// PDH is optional: when it is unavailable, utilization is derived from
	// the per-node running-time counters instead.
	pdhUtil, pdhErr := initPdhGpuUtil()
	if pdhErr != nil {
		logger.Debugf("PDH GPU utilization not available: %s", pdhErr.Error())
	} else {
		logger.Info("using PDH performance counters for GPU utilization")
	}

	ch := make(chan []GpuStat, 1)

	go func() {
		defer close(ch)
		if pdhUtil != nil {
			defer pdhUtil.close()
		}

		// Re-open the adapters; these handles live as long as the goroutine.
		var adapters []d3dkmtAdapterState
		for _, m := range metaList {
			hAdapter, err := d3dkmOpenAdapter(m.luid)
			if err != nil {
				logger.Debugf("reopen adapter failed: %s", err.Error())
				continue
			}
			adapters = append(adapters, d3dkmtAdapterState{
				luid:       m.luid,
				hAdapter:   hAdapter,
				nbSegments: m.nbSegments,
				nodeCount:  m.nodeCount,
				maxFanRPM:  m.maxFanRPM,
				prevNodeRT: make(map[uint32]nodeRunningTimes),
			})
		}

		if len(adapters) == 0 {
			return
		}

		defer func() {
			for _, a := range adapters {
				d3dkmCloseAdapter(a.hAdapter)
			}
		}()

		// Take a baseline sample of every node so the first tick has a
		// previous value to diff against.
		for i := range adapters {
			a := &adapters[i]
			for node := uint32(0); node < a.nodeCount; node++ {
				globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
				if err != nil {
					continue
				}
				a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
			}
			a.prevTime = time.Now()
		}

		ticker := time.NewTicker(every)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				stats := make([]GpuStat, 0, len(adapters))
				now := time.Now()

				var pdhUtilMap map[LUID]float64
				if pdhUtil != nil {
					pdhUtilMap = pdhUtil.collect()
				}

				for i := range adapters {
					a := &adapters[i]

					// Without perf data the adapter is left out of this
					// snapshot entirely.
					perfData, err := d3dkmGetAdapterPerfData(a.hAdapter)
					if err != nil {
						logger.Debugf("adapter %d perfdata: %s", i, err.Error())
						continue
					}

					// Memory totals are summed over all segments that could
					// be queried; each segment is truncated to whole MiB.
					var memUsedMB, memTotalMB int
					for seg := uint32(0); seg < a.nbSegments; seg++ {
						limit, resident, err := d3dkmQuerySegmentStats(a.luid, seg)
						if err != nil {
							continue
						}
						memUsedMB += int(resident / (1024 * 1024))
						memTotalMB += int(limit / (1024 * 1024))
					}

					// Prefer the PDH figure for this adapter when present.
					var gpuUtil float64
					pdhGaveValue := false
					if pdhUtilMap != nil {
						if util, ok := pdhUtilMap[a.luid]; ok {
							gpuUtil = util
							pdhGaveValue = true
						}
					}

					// Fallback: utilization is the maximum over all nodes of
					// the per-node figure computed from counter deltas.
					if !pdhGaveValue && a.nodeCount > 0 {
						elapsedNs := now.Sub(a.prevTime).Nanoseconds()
						elapsed100ns := elapsedNs / 100

						for node := uint32(0); node < a.nodeCount; node++ {
							globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
							if err != nil {
								continue
							}

							if prevRT, ok := a.prevNodeRT[node]; ok {
								// Counter went backwards: re-baseline and
								// skip this node for the current tick.
								if globalRT < prevRT.Global || systemRT < prevRT.System {
									a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
									continue
								}
								nodeUtil := d3dkmtNodeUtil(prevRT, nodeRunningTimes{Global: globalRT, System: systemRT}, elapsed100ns)
								if nodeUtil > gpuUtil {
									gpuUtil = nodeUtil
								}
							}
							a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
						}

						a.prevTime = now
					}

					tempC := d3dkmtTempC(perfData.Temperature)

					fanSpeedPct := d3dkmtFanPct(perfData.FanRPM, a.maxFanRPM)
					powerDrawW := d3dkmtPowerW(perfData.Power)

					var memUtilPct float64
					if memTotalMB > 0 {
						memUtilPct = float64(memUsedMB) / float64(memTotalMB) * 100.0
					}

					// ID and Name are derived from the adapter's index in
					// the goroutine's adapter list.
					stats = append(stats, GpuStat{
						Timestamp:   now,
						ID:          i,
						Name:        fmt.Sprintf("GPU %d", i),
						TempC:       tempC,
						GpuUtilPct:  gpuUtil,
						MemUtilPct:  memUtilPct,
						MemUsedMB:   memUsedMB,
						MemTotalMB:  memTotalMB,
						FanSpeedPct: fanSpeedPct,
						PowerDrawW:  powerDrawW,
					})
				}

				// Non-blocking send: if the single-slot buffer is still
				// occupied, this snapshot is discarded.
				if len(stats) > 0 {
					select {
					case ch <- stats:
					default:
					}
				}
			}
		}
	}()

	return ch, nil
}
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package perf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_FullLoad(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||||
|
cur := nodeRunningTimes{Global: 5000, System: 14000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||||
|
assert.Equal(t, 100.0, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_PartialUtil(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||||
|
cur := nodeRunningTimes{Global: 3000, System: 14000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||||
|
assert.Equal(t, 50.0, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_Identical(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 10000, System: 10000}
|
||||||
|
cur := nodeRunningTimes{Global: 20000, System: 20000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||||
|
assert.Equal(t, 100.0, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_CounterWrap(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 9000, System: 10000}
|
||||||
|
cur := nodeRunningTimes{Global: 1000, System: 10000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||||
|
assert.Equal(t, -1.0, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_SystemWrap(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 1000, System: 9000}
|
||||||
|
cur := nodeRunningTimes{Global: 5000, System: 1000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||||
|
assert.Equal(t, -1.0, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_ZeroDelta(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||||
|
cur := nodeRunningTimes{Global: 1000, System: 10000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||||
|
assert.Equal(t, 0.0, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtNodeUtil_ElapsedFallback(t *testing.T) {
|
||||||
|
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||||
|
cur := nodeRunningTimes{Global: 6000, System: 10000}
|
||||||
|
got := d3dkmtNodeUtil(prev, cur, 50000)
|
||||||
|
assert.InDelta(t, 10.0, got, 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtFanPct_Normal(t *testing.T) {
|
||||||
|
assert.Equal(t, 50.0, d3dkmtFanPct(1500, 3000))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtFanPct_MaxFan(t *testing.T) {
|
||||||
|
assert.Equal(t, 100.0, d3dkmtFanPct(3000, 3000))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtFanPct_OverMaxClamped(t *testing.T) {
|
||||||
|
assert.Equal(t, 100.0, d3dkmtFanPct(4000, 3000))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtFanPct_ZeroMaxFan(t *testing.T) {
|
||||||
|
assert.Equal(t, 0.0, d3dkmtFanPct(1500, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtFanPct_ZeroFanRPM(t *testing.T) {
|
||||||
|
assert.Equal(t, 0.0, d3dkmtFanPct(0, 3000))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtFanPct_BothZero(t *testing.T) {
|
||||||
|
assert.Equal(t, 0.0, d3dkmtFanPct(0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtPowerW(t *testing.T) {
|
||||||
|
assert.Equal(t, 250.0, d3dkmtPowerW(2500))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtPowerW_Zero(t *testing.T) {
|
||||||
|
assert.Equal(t, 0.0, d3dkmtPowerW(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtTempC(t *testing.T) {
|
||||||
|
assert.Equal(t, 65, d3dkmtTempC(650))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD3dkmtTempC_Zero(t *testing.T) {
|
||||||
|
assert.Equal(t, 0, d3dkmtTempC(0))
|
||||||
|
}
|
||||||
@@ -22,6 +22,13 @@ func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monito
|
|||||||
logger.Debugf("nvidia-smi: %s", err.Error())
|
logger.Debugf("nvidia-smi: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ch, err := tryD3DKMT(ctx, every, logger); err == nil {
|
||||||
|
logger.Info("using D3DKMT for GPU monitoring")
|
||||||
|
return ch, nil
|
||||||
|
} else {
|
||||||
|
logger.Debugf("D3DKMT: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return nil, ErrNoGpuTool
|
return nil, ErrNoGpuTool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package perf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
pdhDLL = windows.NewLazySystemDLL("pdh.dll")
|
||||||
|
procPdhOpenQuery = pdhDLL.NewProc("PdhOpenQueryW")
|
||||||
|
procPdhAddEnglishCounter = pdhDLL.NewProc("PdhAddEnglishCounterW")
|
||||||
|
procPdhCollectQueryData = pdhDLL.NewProc("PdhCollectQueryData")
|
||||||
|
procPdhGetFormattedCounterArray = pdhDLL.NewProc("PdhGetFormattedCounterArrayW")
|
||||||
|
procPdhCloseQuery = pdhDLL.NewProc("PdhCloseQuery")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
pdhFmtDouble = 0x00000200
|
||||||
|
pdhMoreData = 0x800007D2
|
||||||
|
pdhNoData = 0x800007D5
|
||||||
|
)
|
||||||
|
|
||||||
|
// pdhCounterValue is the formatted value half of a pdhCounterValueItem: a
// per-item status code followed by the value as a float64.
// NOTE(review): presumably mirrors the Win32 PDH_FMT_COUNTERVALUE layout with
// only the double member of its union declared — confirm against the PDH headers.
type pdhCounterValue struct {
|
||||||
|
// CStatus is the item's status; collect() skips items where it is non-zero.
CStatus uint32
|
||||||
|
// DblVal is the counter value requested with pdhFmtDouble.
DblVal float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// pdhCounterValueItem is one element of the array written by
// PdhGetFormattedCounterArrayW into the buffer that collect() allocates.
// init() asserts that the Go layout is exactly 24 bytes.
type pdhCounterValueItem struct {
|
||||||
|
// SzName points at the NUL-terminated UTF-16 instance name.
SzName *uint16
|
||||||
|
// FmtValue carries the status and formatted value for this instance.
FmtValue pdhCounterValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var item pdhCounterValueItem
|
||||||
|
if unsafe.Sizeof(item) != 24 {
|
||||||
|
panic(fmt.Sprintf("pdhCounterValueItem size %d != expected 24 on x64", unsafe.Sizeof(item)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pdhGpuUtil holds the PDH handles created by initPdhGpuUtil and used by
// collect(); close() releases the query.
type pdhGpuUtil struct {
|
||||||
|
// query is the handle from PdhOpenQueryW; close() resets it to 0.
query uintptr
|
||||||
|
// counter is the handle from PdhAddEnglishCounterW for the GPU Engine counter.
counter uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// initPdhGpuUtil creates a PDH query for the GPU Engine utilization counter.
|
||||||
|
// Returns nil with an error if PDH or the counter is unavailable.
|
||||||
|
func initPdhGpuUtil() (*pdhGpuUtil, error) {
|
||||||
|
var query uintptr
|
||||||
|
if ret, _, _ := procPdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query))); ret != 0 {
|
||||||
|
return nil, fmt.Errorf("PdhOpenQuery: 0x%x", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
path, _ := windows.UTF16PtrFromString(`\GPU Engine(*)\Utilization Percentage`)
|
||||||
|
var counter uintptr
|
||||||
|
if ret, _, _ := procPdhAddEnglishCounter.Call(
|
||||||
|
query, uintptr(unsafe.Pointer(path)), 0, uintptr(unsafe.Pointer(&counter)),
|
||||||
|
); ret != 0 {
|
||||||
|
procPdhCloseQuery.Call(query)
|
||||||
|
return nil, fmt.Errorf("PdhAddEnglishCounter(GPU Engine): 0x%x", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
procPdhCollectQueryData.Call(query)
|
||||||
|
|
||||||
|
return &pdhGpuUtil{query: query, counter: counter}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// close releases the PDH query handle.
|
||||||
|
func (p *pdhGpuUtil) close() {
|
||||||
|
if p.query != 0 {
|
||||||
|
procPdhCloseQuery.Call(p.query)
|
||||||
|
p.query = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collect reads the PDH counter and returns a map of adapter LUID to
|
||||||
|
// aggregated GPU utilization percentage, summed across all engine instances
|
||||||
|
// per adapter and clamped to 100%.
|
||||||
|
func (p *pdhGpuUtil) collect() map[LUID]float64 {
|
||||||
|
ret, _, _ := procPdhCollectQueryData.Call(p.query)
|
||||||
|
if ret != 0 && ret != pdhNoData {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var bufSize uint32
|
||||||
|
var itemCount uint32
|
||||||
|
ret, _, _ = procPdhGetFormattedCounterArray.Call(
|
||||||
|
p.counter, pdhFmtDouble,
|
||||||
|
uintptr(unsafe.Pointer(&bufSize)),
|
||||||
|
uintptr(unsafe.Pointer(&itemCount)),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if ret != pdhMoreData || itemCount == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, bufSize)
|
||||||
|
ret, _, _ = procPdhGetFormattedCounterArray.Call(
|
||||||
|
p.counter, pdhFmtDouble,
|
||||||
|
uintptr(unsafe.Pointer(&bufSize)),
|
||||||
|
uintptr(unsafe.Pointer(&itemCount)),
|
||||||
|
uintptr(unsafe.Pointer(&buf[0])),
|
||||||
|
)
|
||||||
|
if ret != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
itemSize := uint32(unsafe.Sizeof(pdhCounterValueItem{}))
|
||||||
|
result := make(map[LUID]float64)
|
||||||
|
|
||||||
|
for i := uint32(0); i < itemCount; i++ {
|
||||||
|
item := (*pdhCounterValueItem)(unsafe.Pointer(&buf[i*itemSize]))
|
||||||
|
if item.FmtValue.CStatus != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
luid, ok := parsePdhLuid(windows.UTF16PtrToString(item.SzName))
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[luid] += item.FmtValue.DblVal
|
||||||
|
}
|
||||||
|
|
||||||
|
for luid := range result {
|
||||||
|
if result[luid] > 100.0 {
|
||||||
|
result[luid] = 100.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePdhLuid extracts the adapter LUID (high and low parts) from a PDH
|
||||||
|
// GPU Engine instance name (e.g. "pid_1234_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute").
|
||||||
|
func parsePdhLuid(name string) (LUID, bool) {
|
||||||
|
idx := strings.Index(name, "luid_0x")
|
||||||
|
if idx < 0 {
|
||||||
|
return LUID{}, false
|
||||||
|
}
|
||||||
|
rest := name[idx+7:]
|
||||||
|
parts := strings.SplitN(rest, "_", 4)
|
||||||
|
if len(parts) < 3 {
|
||||||
|
return LUID{}, false
|
||||||
|
}
|
||||||
|
hp, err := strconv.ParseUint(parts[0], 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
return LUID{}, false
|
||||||
|
}
|
||||||
|
lpStr := strings.TrimPrefix(parts[1], "0x")
|
||||||
|
lp, err := strconv.ParseUint(lpStr, 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
return LUID{}, false
|
||||||
|
}
|
||||||
|
return LUID{LowPart: uint32(lp), HighPart: int32(hp)}, true
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package perf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParsePdhLuid_Valid(t *testing.T) {
|
||||||
|
name := `pid_25312_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute`
|
||||||
|
got, ok := parsePdhLuid(name)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, uint32(0x000148BF), got.LowPart)
|
||||||
|
assert.Equal(t, int32(0x00000000), got.HighPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePdhLuid_ValidNvidia(t *testing.T) {
|
||||||
|
name := `pid_1388_luid_0x00000000_0x00011372_phys_0_eng_8_engtype_Compute_1`
|
||||||
|
got, ok := parsePdhLuid(name)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, uint32(0x00011372), got.LowPart)
|
||||||
|
assert.Equal(t, int32(0x00000000), got.HighPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePdhLuid_NonZeroHighPart(t *testing.T) {
|
||||||
|
name := `pid_1234_luid_0x00000001_0x0000C85A_phys_0_eng_5_engtype_Copy`
|
||||||
|
got, ok := parsePdhLuid(name)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, uint32(0x0000C85A), got.LowPart)
|
||||||
|
assert.Equal(t, int32(0x00000001), got.HighPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePdhLuid_InvalidNoLuid(t *testing.T) {
|
||||||
|
_, ok := parsePdhLuid("invalid_string_without_luid")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePdhLuid_InvalidEmpty(t *testing.T) {
|
||||||
|
_, ok := parsePdhLuid("")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePdhLuid_InvalidHex(t *testing.T) {
|
||||||
|
_, ok := parsePdhLuid("pid_1234_luid_0xZZZZ_0xGGGG_phys_0")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePdhLuid_ShortAfterLuid(t *testing.T) {
|
||||||
|
_, ok := parsePdhLuid("pid_1234_luid_0x00000000")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
+122
-417
@@ -11,6 +11,8 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
type shutdownReq struct {
|
type shutdownReq struct {
|
||||||
@@ -24,56 +26,16 @@ type unloadReq struct {
|
|||||||
respond chan struct{}
|
respond chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerReq struct {
|
// baseRouter owns the channels, run-loop, and process machinery shared by every
|
||||||
model string
|
// concrete router. Concrete routers embed *baseRouter and supply a
|
||||||
ctx context.Context
|
// scheduler.Swapper describing how eviction sets are decided. baseRouter
|
||||||
respond chan handlerResp
|
// implements scheduler.Effects so the scheduler can call back for side-effects.
|
||||||
positionCh chan int
|
|
||||||
}
|
|
||||||
|
|
||||||
type handlerResp struct {
|
|
||||||
handleFunc http.HandlerFunc
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type swapDone struct {
|
|
||||||
modelID string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type serveDoneEvent struct {
|
|
||||||
modelID string
|
|
||||||
}
|
|
||||||
|
|
||||||
type activeSwap struct {
|
|
||||||
modelID string
|
|
||||||
evict []string
|
|
||||||
waiters []handlerReq
|
|
||||||
}
|
|
||||||
|
|
||||||
// swapPlanner is the only piece of behaviour that differs between concrete
|
|
||||||
// routers. baseRouter never inspects its internals.
|
|
||||||
type swapPlanner interface {
|
|
||||||
// EvictionFor returns running model IDs that must be stopped before
|
|
||||||
// target can serve. alsoRunning lists models the baseRouter has already
|
|
||||||
// committed to loading (in-flight swaps) which the planner cannot see
|
|
||||||
// via process.State() yet. Pure decision; must not log.
|
|
||||||
EvictionFor(target string, alsoRunning []string) []string
|
|
||||||
|
|
||||||
// OnSwapStart runs once at the start of every swap. Planners may log
|
|
||||||
// their decision here at whatever verbosity they choose.
|
|
||||||
OnSwapStart(target string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseRouter owns the channels, run-loop, and orchestration code shared by
|
|
||||||
// every concrete router. Concrete routers embed *baseRouter and supply a
|
|
||||||
// swapPlanner that captures how their eviction set is decided.
|
|
||||||
type baseRouter struct {
|
type baseRouter struct {
|
||||||
name string
|
name string
|
||||||
config config.Config
|
config config.Config
|
||||||
processes map[string]process.Process
|
processes map[string]process.Process
|
||||||
logger *logmon.Monitor
|
logger *logmon.Monitor
|
||||||
planner swapPlanner
|
schedule scheduler.Scheduler
|
||||||
|
|
||||||
// shutdownCtx governs the request machinery: cancelling it tells grant()
|
// shutdownCtx governs the request machinery: cancelling it tells grant()
|
||||||
// and ServeHTTP to stop granting and reject callers. It is deliberately
|
// and ServeHTTP to stop granting and reject callers. It is deliberately
|
||||||
@@ -90,11 +52,12 @@ type baseRouter struct {
|
|||||||
procCtx context.Context
|
procCtx context.Context
|
||||||
procCancel context.CancelFunc
|
procCancel context.CancelFunc
|
||||||
|
|
||||||
handlerCh chan handlerReq
|
handlerCh chan scheduler.HandlerReq
|
||||||
|
cancelCh chan scheduler.HandlerReq
|
||||||
shutdownCh chan shutdownReq
|
shutdownCh chan shutdownReq
|
||||||
unloadCh chan unloadReq
|
unloadCh chan unloadReq
|
||||||
swapDoneCh chan swapDone
|
swapDoneCh chan scheduler.SwapDone
|
||||||
serveDoneCh chan serveDoneEvent
|
serveDoneCh chan scheduler.ServeDoneEvent
|
||||||
|
|
||||||
runDone chan struct{}
|
runDone chan struct{}
|
||||||
|
|
||||||
@@ -106,26 +69,38 @@ type baseRouter struct {
|
|||||||
testProcessed chan struct{}
|
testProcessed chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
|
func newBaseRouter(
|
||||||
|
name string,
|
||||||
|
conf config.Config,
|
||||||
|
processes map[string]process.Process,
|
||||||
|
logger *logmon.Monitor,
|
||||||
|
planner scheduler.Swapper,
|
||||||
|
) (*baseRouter, error) {
|
||||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||||
procCtx, procCancel := context.WithCancel(context.Background())
|
procCtx, procCancel := context.WithCancel(context.Background())
|
||||||
return &baseRouter{
|
b := &baseRouter{
|
||||||
name: name,
|
name: name,
|
||||||
config: conf,
|
config: conf,
|
||||||
processes: processes,
|
processes: processes,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
planner: planner,
|
|
||||||
shutdownCtx: shutdownCtx,
|
shutdownCtx: shutdownCtx,
|
||||||
shutdownFn: shutdownFn,
|
shutdownFn: shutdownFn,
|
||||||
procCtx: procCtx,
|
procCtx: procCtx,
|
||||||
procCancel: procCancel,
|
procCancel: procCancel,
|
||||||
handlerCh: make(chan handlerReq),
|
handlerCh: make(chan scheduler.HandlerReq),
|
||||||
|
cancelCh: make(chan scheduler.HandlerReq),
|
||||||
shutdownCh: make(chan shutdownReq),
|
shutdownCh: make(chan shutdownReq),
|
||||||
unloadCh: make(chan unloadReq),
|
unloadCh: make(chan unloadReq),
|
||||||
swapDoneCh: make(chan swapDone),
|
swapDoneCh: make(chan scheduler.SwapDone),
|
||||||
serveDoneCh: make(chan serveDoneEvent),
|
serveDoneCh: make(chan scheduler.ServeDoneEvent),
|
||||||
runDone: make(chan struct{}),
|
runDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
sched, err := scheduler.New(conf, name, logger, planner, b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
b.schedule = sched
|
||||||
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baseRouter) notifyProcessed() {
|
func (b *baseRouter) notifyProcessed() {
|
||||||
@@ -137,30 +112,31 @@ func (b *baseRouter) notifyProcessed() {
|
|||||||
func (b *baseRouter) run() {
|
func (b *baseRouter) run() {
|
||||||
defer close(b.runDone)
|
defer close(b.runDone)
|
||||||
|
|
||||||
active := make(map[string]*activeSwap)
|
|
||||||
inFlight := make(map[string]int)
|
|
||||||
var queued []handlerReq
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case req := <-b.shutdownCh:
|
case req := <-b.shutdownCh:
|
||||||
b.handleShutdown(req, active, queued)
|
b.handleShutdown(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
case req := <-b.handlerCh:
|
case req := <-b.handlerCh:
|
||||||
b.handleRequest(req, active, inFlight, &queued)
|
b.schedule.OnRequest(req)
|
||||||
|
b.notifyProcessed()
|
||||||
|
|
||||||
|
case req := <-b.cancelCh:
|
||||||
|
b.schedule.OnCancel(req)
|
||||||
b.notifyProcessed()
|
b.notifyProcessed()
|
||||||
|
|
||||||
case req := <-b.unloadCh:
|
case req := <-b.unloadCh:
|
||||||
b.handleUnload(req, active, inFlight, &queued)
|
b.schedule.OnUnload(req.targets, req.timeout)
|
||||||
|
close(req.respond)
|
||||||
b.notifyProcessed()
|
b.notifyProcessed()
|
||||||
|
|
||||||
case ev := <-b.swapDoneCh:
|
case ev := <-b.swapDoneCh:
|
||||||
b.handleSwapDone(ev, active, inFlight, &queued)
|
b.schedule.OnSwapDone(ev)
|
||||||
b.notifyProcessed()
|
b.notifyProcessed()
|
||||||
|
|
||||||
case ev := <-b.serveDoneCh:
|
case ev := <-b.serveDoneCh:
|
||||||
b.handleServeDone(ev, active, inFlight, &queued)
|
b.schedule.OnServeDone(ev)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -177,37 +153,68 @@ func (b *baseRouter) run() {
|
|||||||
// down, the send never lands, one of the other select cases fires, and we
|
// down, the send never lands, one of the other select cases fires, and we
|
||||||
// report back that the grant did NOT happen.
|
// report back that the grant did NOT happen.
|
||||||
//
|
//
|
||||||
// That distinction matters for in-flight bookkeeping — see grantHandler.
|
// That distinction matters for in-flight bookkeeping — see GrantServe.
|
||||||
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
|
func (b *baseRouter) grant(req scheduler.HandlerReq, resp scheduler.HandlerResp) bool {
|
||||||
select {
|
select {
|
||||||
case req.respond <- resp:
|
case req.Respond <- resp:
|
||||||
return true
|
return true
|
||||||
case <-req.ctx.Done():
|
case <-req.Ctx.Done():
|
||||||
return false
|
return false
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// grantHandler is the "this caller can now use process p" path. It does
|
// ModelState implements scheduler.Effects.
|
||||||
// two things that must stay locked together:
|
func (b *baseRouter) ModelState(modelID string) (process.ProcessState, bool) {
|
||||||
//
|
p, ok := b.processes[modelID]
|
||||||
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
|
if !ok {
|
||||||
// HTTP request finishes, the run loop hears about it.
|
var zero process.ProcessState
|
||||||
// 2. Bump inFlight[modelID] so the router knows this process is busy and
|
return zero, false
|
||||||
// refuses to evict it until the count comes back down.
|
|
||||||
//
|
|
||||||
// The increment is gated on grant() returning true. If grant() returns
|
|
||||||
// false, the caller already walked away and trackedServe will never run —
|
|
||||||
// which means no matching decrement will ever arrive on serveDoneCh.
|
|
||||||
// Incrementing in that case would strand the counter at >0 forever and
|
|
||||||
// the router would never again be willing to swap this model out.
|
|
||||||
//
|
|
||||||
// In short: increment if and only if we know a decrement is coming.
|
|
||||||
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
|
|
||||||
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
|
|
||||||
inFlight[modelID]++
|
|
||||||
}
|
}
|
||||||
|
return p.State(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSwap implements scheduler.Effects, launching the swap goroutine.
// It returns immediately: doSwap runs on its own goroutine with modelID as the
// swap target and evict as the list of models passed to it for stopping.
|
||||||
|
func (b *baseRouter) StartSwap(modelID string, evict []string) {
|
||||||
|
go b.doSwap(modelID, evict)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrantError implements scheduler.Effects.
// It hands err to the caller behind req through grant(); whether the caller
// actually received it is not reported back.
|
||||||
|
func (b *baseRouter) GrantError(req scheduler.HandlerReq, err error) {
|
||||||
|
b.grant(req, scheduler.HandlerResp{Err: err})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrantServe implements scheduler.Effects. It hands the caller a wrapped
|
||||||
|
// p.ServeHTTP (via trackedServe) so the run loop hears about the request
|
||||||
|
// finishing, and reports whether the caller received it. The scheduler bumps
|
||||||
|
// its in-flight count only on a true return: if grant() returns false the
|
||||||
|
// caller already walked away and trackedServe will never run, so no matching
|
||||||
|
// decrement will ever arrive — incrementing would strand the counter at >0 and
|
||||||
|
// the router would never again be willing to evict this model.
|
||||||
|
func (b *baseRouter) GrantServe(req scheduler.HandlerReq, modelID string) bool {
|
||||||
|
p := b.processes[modelID]
|
||||||
|
return b.grant(req, scheduler.HandlerResp{HandleFunc: b.trackedServe(modelID, p)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopProcesses implements scheduler.Effects, stopping the named processes in
|
||||||
|
// parallel and blocking until all have stopped.
|
||||||
|
func (b *baseRouter) StopProcesses(timeout time.Duration, ids []string) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, id := range ids {
|
||||||
|
p, ok := b.processes[id]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id string, p process.Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := p.Stop(timeout); err != nil {
|
||||||
|
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||||
|
}
|
||||||
|
}(id, p)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
||||||
@@ -224,7 +231,7 @@ func (b *baseRouter) trackedServe(modelID string, p process.Process) http.Handle
|
|||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer func() {
|
defer func() {
|
||||||
select {
|
select {
|
||||||
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
|
case b.serveDoneCh <- scheduler.ServeDoneEvent{ModelID: modelID}:
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -232,240 +239,6 @@ func (b *baseRouter) trackedServe(modelID string, p process.Process) http.Handle
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRequest decides what to do with one incoming ServeHTTP request. It is
|
|
||||||
// called from run() and never blocks indefinitely: any work that has to wait
|
|
||||||
// (starting a process, stopping siblings, waiting for ready) is deferred to
|
|
||||||
// a swap goroutine and reported back via swapDoneCh.
|
|
||||||
//
|
|
||||||
// The decision tree, in order:
|
|
||||||
//
|
|
||||||
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
|
|
||||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
|
||||||
// one swap serves all callers that asked for the same model.
|
|
||||||
// 3. Fast path — the target process is already ready, the planner sees
|
|
||||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
|
||||||
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
|
|
||||||
// 4. Would collide with an in-flight swap (we'd stop their target, or
|
|
||||||
// they're stopping us) — park in the queue for handleSwapDone to drain.
|
|
||||||
// 5. Would evict a process that is still handling requests — park in the
|
|
||||||
// queue. handleServeDone will retry when the busy process drains.
|
|
||||||
// 6. Otherwise — start a new swap. This may run in parallel with other
|
|
||||||
// active swaps when their evict sets don't intersect.
|
|
||||||
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
|
||||||
// (1) Unknown model.
|
|
||||||
p, ok := b.processes[req.model]
|
|
||||||
if !ok {
|
|
||||||
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
|
|
||||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// (2) Join an in-flight swap for the same model.
|
|
||||||
if s, ok := active[req.model]; ok {
|
|
||||||
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
|
|
||||||
s.waiters = append(s.waiters, req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
|
||||||
|
|
||||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
|
||||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
|
||||||
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
|
|
||||||
b.grantHandler(req, req.model, p, inFlight)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// (4) Collision with an in-flight swap — queue.
|
|
||||||
if collidesWith(req.model, evict, active) {
|
|
||||||
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
|
|
||||||
*queued = append(*queued, req)
|
|
||||||
b.broadcastQueuePositions(*queued)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// (5) Would evict a busy process — queue until it drains.
|
|
||||||
if conflictsWithInFlight(evict, inFlight) {
|
|
||||||
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
|
|
||||||
*queued = append(*queued, req)
|
|
||||||
b.broadcastQueuePositions(*queued)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// (6) Start a new (possibly parallel) swap.
|
|
||||||
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
|
|
||||||
s := b.startSwap(req, evict)
|
|
||||||
active[s.modelID] = s
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleSwapDone is called from run() when a swap goroutine reports that it
|
|
||||||
// has finished. It fans out the result to every waiter that joined this swap,
|
|
||||||
// removes the swap from the active map, and then walks the queue once,
|
|
||||||
// promoting any items that no longer collide with the remaining active set.
|
|
||||||
// FIFO order is preserved: items still blocked stay in place.
|
|
||||||
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
|
||||||
s, ok := active[ev.modelID]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(active, ev.modelID)
|
|
||||||
|
|
||||||
for _, w := range s.waiters {
|
|
||||||
if ev.err != nil {
|
|
||||||
b.grant(w, handlerResp{err: ev.err})
|
|
||||||
} else {
|
|
||||||
p := b.processes[ev.modelID]
|
|
||||||
b.grantHandler(w, ev.modelID, p, inFlight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b.drainQueue(active, inFlight, queued)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleServeDone is called from run() each time a tracked ServeHTTP
|
|
||||||
// finishes. It decrements the per-model in-flight count and, when that
|
|
||||||
// drops to zero, retries the queue: requests whose swap was deferred
|
|
||||||
// because they would have evicted this (now-idle) process can now proceed.
|
|
||||||
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
|
||||||
inFlight[ev.modelID]--
|
|
||||||
if inFlight[ev.modelID] <= 0 {
|
|
||||||
delete(inFlight, ev.modelID)
|
|
||||||
b.drainQueue(active, inFlight, queued)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// drainQueue walks the queued requests in order, re-running the handleRequest
|
|
||||||
// decision tree against the (now smaller) active set. Items that can now start
|
|
||||||
// or join become satisfied; items still blocked remain queued in original
|
|
||||||
// order so they get another chance on the next swap completion.
|
|
||||||
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
|
||||||
if len(*queued) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pending := *queued
|
|
||||||
var remaining []handlerReq
|
|
||||||
for _, req := range pending {
|
|
||||||
p, ok := b.processes[req.model]
|
|
||||||
if !ok {
|
|
||||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if s, ok := active[req.model]; ok {
|
|
||||||
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
|
|
||||||
s.waiters = append(s.waiters, req)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
|
||||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
|
||||||
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
|
|
||||||
b.grantHandler(req, req.model, p, inFlight)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if collidesWith(req.model, evict, active) {
|
|
||||||
remaining = append(remaining, req)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if conflictsWithInFlight(evict, inFlight) {
|
|
||||||
remaining = append(remaining, req)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
|
|
||||||
s := b.startSwap(req, evict)
|
|
||||||
active[s.modelID] = s
|
|
||||||
}
|
|
||||||
*queued = remaining
|
|
||||||
b.broadcastQueuePositions(*queued)
|
|
||||||
}
|
|
||||||
|
|
||||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
|
||||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
|
||||||
// drained first so the consumer always sees the latest position.
|
|
||||||
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
|
|
||||||
for i, req := range queued {
|
|
||||||
pos := i + 1
|
|
||||||
select {
|
|
||||||
case req.positionCh <- pos:
|
|
||||||
default:
|
|
||||||
select {
|
|
||||||
case <-req.positionCh:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case req.positionCh <- pos:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
|
|
||||||
swap := &activeSwap{
|
|
||||||
modelID: initial.model,
|
|
||||||
evict: evict,
|
|
||||||
waiters: []handlerReq{initial},
|
|
||||||
}
|
|
||||||
b.planner.OnSwapStart(initial.model)
|
|
||||||
go b.doSwap(initial.model, evict)
|
|
||||||
return swap
|
|
||||||
}
|
|
||||||
|
|
||||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
|
||||||
// baseRouter passes this to the planner so eviction decisions account for
|
|
||||||
// models that have been committed to but have not yet transitioned to
|
|
||||||
// StateStarting in their process state machine.
|
|
||||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
|
||||||
if len(active) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make([]string, 0, len(active))
|
|
||||||
for id := range active {
|
|
||||||
if id == exclude {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out = append(out, id)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// collidesWith reports whether a new swap with this target and evict set can
|
|
||||||
// safely run alongside the currently active swaps. Same-target callers should
|
|
||||||
// JOIN (handled before this) — they do not collide with themselves.
|
|
||||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
|
||||||
for id, s := range active {
|
|
||||||
if id == target {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if containsString(evict, id) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if containsString(s.evict, target) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
|
||||||
// requests. Stopping a busy process would cancel its callers' connections,
|
|
||||||
// so the router defers the swap until those callers finish.
|
|
||||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
|
||||||
for _, m := range evict {
|
|
||||||
if inFlight[m] > 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsString(xs []string, s string) bool {
|
|
||||||
for _, x := range xs {
|
|
||||||
if x == s {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||||
timeout := b.healthCheckTimeout()
|
timeout := b.healthCheckTimeout()
|
||||||
|
|
||||||
@@ -493,31 +266,24 @@ func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
|||||||
err := target.WaitReady(b.shutdownCtx)
|
err := target.WaitReady(b.shutdownCtx)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
|
case b.swapDoneCh <- scheduler.SwapDone{ModelID: modelID, Err: err}:
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
|
func (b *baseRouter) handleShutdown(req shutdownReq) {
|
||||||
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
||||||
|
|
||||||
// Cancel shutdownCtx first so any waiter that is currently parked on
|
// Cancel shutdownCtx first so any waiter that is currently parked on
|
||||||
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
||||||
// The grant calls below then either land (waiter happened to receive
|
// The OnShutdown grants below then either land (waiter happened to receive
|
||||||
// before noticing shutdown) or fall through immediately via grant's
|
// before noticing shutdown) or fall through immediately via grant's
|
||||||
// shutdownCtx case — either way the waiter sees a non-OK response.
|
// shutdownCtx case — either way the waiter sees a non-OK response.
|
||||||
// This does NOT touch processes: their lifetime is procCtx, cancelled
|
// This does NOT touch processes: their lifetime is procCtx, cancelled
|
||||||
// only after the graceful Stop() calls below have reaped them.
|
// only after the graceful Stop() calls below have reaped them.
|
||||||
b.shutdownFn()
|
b.shutdownFn()
|
||||||
|
|
||||||
for _, s := range active {
|
b.schedule.OnShutdown(shutdownErr)
|
||||||
for _, w := range s.waiters {
|
|
||||||
b.grant(w, handlerResp{err: shutdownErr})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, w := range queued {
|
|
||||||
b.grant(w, handlerResp{err: shutdownErr})
|
|
||||||
}
|
|
||||||
|
|
||||||
stopTimeout := req.timeout
|
stopTimeout := req.timeout
|
||||||
if stopTimeout <= 0 {
|
if stopTimeout <= 0 {
|
||||||
@@ -628,75 +394,6 @@ func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
|
|||||||
<-req.respond
|
<-req.respond
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleUnload runs on the run loop in response to an Unload call. It
|
|
||||||
// reconciles router-owned state with the impending Stop, then performs
|
|
||||||
// the Stop synchronously so callers of Unload remain blocked until each
|
|
||||||
// targeted process has actually exited.
|
|
||||||
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
|
||||||
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
|
|
||||||
|
|
||||||
targetSet := make(map[string]bool, len(req.targets))
|
|
||||||
for _, id := range req.targets {
|
|
||||||
targetSet[id] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release waiters of any in-flight swap whose target is being
|
|
||||||
// unloaded. The swap goroutine itself is left to finish on its own;
|
|
||||||
// when its swapDone arrives, handleSwapDone will find no entry in
|
|
||||||
// active and silently drop it.
|
|
||||||
for id := range targetSet {
|
|
||||||
s, ok := active[id]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, w := range s.waiters {
|
|
||||||
b.grant(w, handlerResp{err: unloadErr})
|
|
||||||
}
|
|
||||||
delete(active, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Drop queued requests addressed to unloaded models. Requests for
|
|
||||||
// other models stay queued and may benefit from drainQueue at the end.
|
|
||||||
if len(*queued) > 0 {
|
|
||||||
kept := (*queued)[:0]
|
|
||||||
for _, w := range *queued {
|
|
||||||
if targetSet[w.model] {
|
|
||||||
b.grant(w, handlerResp{err: unloadErr})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
kept = append(kept, w)
|
|
||||||
}
|
|
||||||
*queued = kept
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop the targeted processes. Done synchronously so Unload's caller
|
|
||||||
// can rely on "after Unload returns, the process is stopped". inFlight
|
|
||||||
// is intentionally NOT cleared here: each dying handler will fire its
|
|
||||||
// trackedServe defer and reach handleServeDone in the normal way once
|
|
||||||
// the run loop is free again.
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for id := range targetSet {
|
|
||||||
p, ok := b.processes[id]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go func(id string, p process.Process) {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := p.Stop(req.timeout); err != nil {
|
|
||||||
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
|
|
||||||
}
|
|
||||||
}(id, p)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// Removing entries from active above may have unblocked queued
|
|
||||||
// requests that previously collided with the now-cancelled swaps.
|
|
||||||
b.drainQueue(active, inFlight, queued)
|
|
||||||
|
|
||||||
close(req.respond)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||||
if !b.shuttingDown.CompareAndSwap(false, true) {
|
if !b.shuttingDown.CompareAndSwap(false, true) {
|
||||||
return fmt.Errorf("%s shutdown already in progress", b.name)
|
return fmt.Errorf("%s shutdown already in progress", b.name)
|
||||||
@@ -712,24 +409,24 @@ func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
|||||||
|
|
||||||
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
if b.shuttingDown.Load() {
|
if b.shuttingDown.Load() {
|
||||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := FetchContext(req, b.config)
|
data, err := shared.FetchContext(req, b.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SendError(w, req, err)
|
shared.SendError(w, req, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hr := handlerReq{
|
hr := scheduler.HandlerReq{
|
||||||
model: data.ModelID,
|
Model: data.ModelID,
|
||||||
ctx: req.Context(),
|
Ctx: req.Context(),
|
||||||
// Unbuffered: a successful send on respond proves the waiter is
|
// Unbuffered: a successful send on Respond proves the waiter is
|
||||||
// alive and consuming. grant() relies on this to avoid handing a
|
// alive and consuming. grant() relies on this to avoid handing a
|
||||||
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
||||||
respond: make(chan handlerResp),
|
Respond: make(chan scheduler.HandlerResp),
|
||||||
positionCh: make(chan int, 1),
|
PositionCh: make(chan int, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -737,7 +434,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
case <-req.Context().Done():
|
case <-req.Context().Done():
|
||||||
return
|
return
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -757,7 +454,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case pos := <-hr.positionCh:
|
case pos := <-hr.PositionCh:
|
||||||
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
||||||
case <-swapCtx.Done():
|
case <-swapCtx.Done():
|
||||||
return
|
return
|
||||||
@@ -779,22 +476,30 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp handlerResp
|
var resp scheduler.HandlerResp
|
||||||
select {
|
select {
|
||||||
case resp = <-hr.respond:
|
case resp = <-hr.Respond:
|
||||||
finishLoading()
|
finishLoading()
|
||||||
case <-req.Context().Done():
|
case <-req.Context().Done():
|
||||||
finishLoading()
|
finishLoading()
|
||||||
|
// Notify the scheduler so it can prune this request from its queue
|
||||||
|
// and swap waiters. Without this, a queued request whose client left
|
||||||
|
// would sit in the scheduler until drainQueue eventually starts a
|
||||||
|
// wasted model load for it.
|
||||||
|
select {
|
||||||
|
case b.cancelCh <- hr:
|
||||||
|
case <-b.shutdownCtx.Done():
|
||||||
|
}
|
||||||
return
|
return
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
finishLoading()
|
finishLoading()
|
||||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.err != nil {
|
if resp.Err != nil {
|
||||||
SendError(w, req, resp.err)
|
shared.SendError(w, req, resp.Err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.handleFunc(w, req)
|
resp.HandleFunc(w, req)
|
||||||
}
|
}
|
||||||
|
|||||||
+15
-614
@@ -5,35 +5,34 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||||
)
|
)
|
||||||
|
|
||||||
// stubPlanner is a swapPlanner that returns a fixed eviction list per target
|
// These tests cover baseRouter's own machinery — the run loop, process
|
||||||
// and never logs. It lets the base-router tests cover shared run-loop
|
// lifecycle (doSwap), grant/ServeHTTP plumbing, Unload, and Shutdown. The
|
||||||
// behaviour without dragging in either real router's eviction rules.
|
// scheduling decision logic (queueing, collation, eviction collisions) lives in
|
||||||
type stubPlanner struct {
|
// the scheduler package and is tested directly there; see fifo_test.go.
|
||||||
evict map[string][]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
// stubPlanner evicts nothing. baseRouter tests drive the run loop through the
|
||||||
if s.evict == nil {
|
// default FIFO scheduler without exercising any particular eviction policy.
|
||||||
return nil
|
type stubPlanner struct{}
|
||||||
}
|
|
||||||
return s.evict[target]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubPlanner) OnSwapStart(string) {}
|
func (s *stubPlanner) EvictionFor(string, []string) []string { return nil }
|
||||||
|
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||||
|
|
||||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
|
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
conf := config.Config{HealthCheckTimeout: 5}
|
conf := config.Config{HealthCheckTimeout: 5}
|
||||||
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
|
b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newBaseRouter: %v", err)
|
||||||
|
}
|
||||||
b.testProcessed = make(chan struct{}, 64)
|
b.testProcessed = make(chan struct{}, 64)
|
||||||
go b.run()
|
go b.run()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -157,114 +156,6 @@ func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
|
|
||||||
// rejoins router state: a request whose swap to the unloaded model is
|
|
||||||
// still in progress receives an error, instead of being abandoned
|
|
||||||
// against a process that's about to vanish.
|
|
||||||
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
// autoReady=false: the swap parks on WaitReady so we can interrupt
|
|
||||||
// it with Unload before it completes.
|
|
||||||
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w, newRequest("a"))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
|
|
||||||
<-a.runStarted
|
|
||||||
|
|
||||||
b.Unload(time.Second, "a")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("ServeHTTP did not return after Unload")
|
|
||||||
}
|
|
||||||
if w.Code == http.StatusOK {
|
|
||||||
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
if a.State() != process.StateStopped {
|
|
||||||
t.Errorf("a state=%q want stopped", a.State())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
|
|
||||||
// for an unloaded model receive an error rather than sitting forever in
|
|
||||||
// the queue against state the router no longer maintains.
|
|
||||||
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
pb := newFakeProcess("b")
|
|
||||||
// Loading B evicts A — so a request for B while A is loading queues.
|
|
||||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
|
||||||
|
|
||||||
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
|
|
||||||
w1 := httptest.NewRecorder()
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w1, newRequest("a"))
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
<-a.runStarted
|
|
||||||
|
|
||||||
// r2 for B collides with A's in-flight swap and queues.
|
|
||||||
w2 := httptest.NewRecorder()
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w2, newRequest("b"))
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
// Unload B — r2 (queued, targeting B) must be released with an error.
|
|
||||||
b.Unload(time.Second, "b")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done2:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("queued B request did not return after Unload(b)")
|
|
||||||
}
|
|
||||||
if w2.Code == http.StatusOK {
|
|
||||||
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
|
|
||||||
}
|
|
||||||
if got := pb.runCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release r1 so the test cleans up cleanly.
|
|
||||||
a.markReady()
|
|
||||||
select {
|
|
||||||
case <-done1:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("r1 did not complete after a.markReady")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBaseRouter_FastPath(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
a.markReady()
|
|
||||||
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
b.ServeHTTP(w, newRequest("a"))
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
if got := a.serveCalls.Load(); got != 1 {
|
|
||||||
t.Errorf("serveCalls=%d want 1", got)
|
|
||||||
}
|
|
||||||
if got := a.runCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||||
a := newFakeProcess("a")
|
a := newFakeProcess("a")
|
||||||
a.autoReady = true
|
a.autoReady = true
|
||||||
@@ -285,43 +176,6 @@ func TestBaseRouter_OnDemandStart(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
// autoReady=false so the swap parks on WaitReady until we release it.
|
|
||||||
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
|
||||||
|
|
||||||
const N = 5
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
codes := make([]int, N)
|
|
||||||
for i := 0; i < N; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
b.ServeHTTP(w, newRequest("a"))
|
|
||||||
codes[i] = w.Code
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
|
|
||||||
<-a.runStarted // swap goroutine reached Run()
|
|
||||||
a.markReady()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
for i, c := range codes {
|
|
||||||
if c != http.StatusOK {
|
|
||||||
t.Errorf("request %d: status=%d", i, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if got := a.runCalls.Load(); got != 1 {
|
|
||||||
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
|
|
||||||
}
|
|
||||||
if got := a.serveCalls.Load(); got != N {
|
|
||||||
t.Errorf("serveCalls=%d want %d", got, N)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBaseRouter_ContextCancel(t *testing.T) {
|
func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||||
a := newFakeProcess("a")
|
a := newFakeProcess("a")
|
||||||
// autoReady=false so swap parks forever until we mark ready.
|
// autoReady=false so swap parks forever until we mark ready.
|
||||||
@@ -364,459 +218,6 @@ func TestBaseRouter_ContextCancel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
pa := newFakeProcess("b")
|
|
||||||
|
|
||||||
// Loading b must stop a.
|
|
||||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
|
|
||||||
|
|
||||||
// First request starts a swap to A; A's autoReady=false so it parks.
|
|
||||||
w1 := httptest.NewRecorder()
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w1, newRequest("a"))
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
<-a.runStarted
|
|
||||||
|
|
||||||
// Second request for B should queue while A's swap is in flight.
|
|
||||||
w2 := httptest.NewRecorder()
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w2, newRequest("b"))
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
if got := pa.runCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release A's swap. B's swap should then run.
|
|
||||||
a.markReady()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
|
|
||||||
<-pa.runStarted
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done1:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("A request did not complete")
|
|
||||||
}
|
|
||||||
pa.markReady()
|
|
||||||
select {
|
|
||||||
case <-done2:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("queued B request did not complete after A's swap")
|
|
||||||
}
|
|
||||||
if w2.Code != http.StatusOK {
|
|
||||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
|
||||||
}
|
|
||||||
if got := a.stopCalls.Load(); got != 1 {
|
|
||||||
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
|
|
||||||
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
|
|
||||||
// second request for each model rides the fast path — either joining the
|
|
||||||
// active swap, or being pulled out of the queue when handleSwapDone promotes
|
|
||||||
// the next model.
|
|
||||||
func TestBaseRouter_QueueCollation(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
pb := newFakeProcess("b")
|
|
||||||
pc := newFakeProcess("c")
|
|
||||||
|
|
||||||
// Each model evicts the other two so all swaps are mutually exclusive.
|
|
||||||
planner := &stubPlanner{evict: map[string][]string{
|
|
||||||
"a": {"b", "c"},
|
|
||||||
"b": {"a", "c"},
|
|
||||||
"c": {"a", "b"},
|
|
||||||
}}
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
|
||||||
|
|
||||||
var (
|
|
||||||
completedMu sync.Mutex
|
|
||||||
completed []string
|
|
||||||
)
|
|
||||||
record := func(id string) {
|
|
||||||
completedMu.Lock()
|
|
||||||
defer completedMu.Unlock()
|
|
||||||
completed = append(completed, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := []string{"a", "b", "c", "a", "b", "c"}
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for _, id := range ids {
|
|
||||||
id := id
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
b.ServeHTTP(w, newRequest(id))
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record(id)
|
|
||||||
}()
|
|
||||||
// Wait for run() to absorb this request before launching the next,
|
|
||||||
// so handlerCh receives them in launch order.
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// All 6 are now parked in run()'s waiters/queue. Release each swap in
|
|
||||||
// sequence, waiting deterministically for each promotion to fire.
|
|
||||||
<-a.runStarted
|
|
||||||
a.markReady()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
|
|
||||||
|
|
||||||
<-pb.runStarted
|
|
||||||
pb.markReady()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
|
|
||||||
|
|
||||||
<-pc.runStarted
|
|
||||||
pc.markReady()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if got := len(completed); got != 6 {
|
|
||||||
t.Fatalf("completed=%v want 6", completed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
|
|
||||||
// but waiter goroutines may be scheduled in any order after their respond
|
|
||||||
// channel fires, so completion order isn't deterministic. Per-model counts
|
|
||||||
// (combined with the runCalls checks below) are sufficient to prove queue
|
|
||||||
// collation collapsed each pair into a single swap.
|
|
||||||
aDone, bDone, cDone := 0, 0, 0
|
|
||||||
for _, id := range completed {
|
|
||||||
switch id {
|
|
||||||
case "a":
|
|
||||||
aDone++
|
|
||||||
case "b":
|
|
||||||
bDone++
|
|
||||||
case "c":
|
|
||||||
cDone++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if aDone != 2 || bDone != 2 || cDone != 2 {
|
|
||||||
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single swap per model — the second request for each must have ridden
|
|
||||||
// the fast path (joined active swap or joined a queued sibling), not
|
|
||||||
// triggered an extra Run.
|
|
||||||
if got := a.runCalls.Load(); got != 1 {
|
|
||||||
t.Errorf("a.runCalls=%d want 1", got)
|
|
||||||
}
|
|
||||||
if got := pb.runCalls.Load(); got != 1 {
|
|
||||||
t.Errorf("b.runCalls=%d want 1", got)
|
|
||||||
}
|
|
||||||
if got := pc.runCalls.Load(); got != 1 {
|
|
||||||
t.Errorf("c.runCalls=%d want 1", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
|
|
||||||
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
|
|
||||||
// before either process is marked ready.
|
|
||||||
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
pb := newFakeProcess("b")
|
|
||||||
|
|
||||||
// Empty evict sets for both: they can load in parallel.
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
|
||||||
|
|
||||||
w1 := httptest.NewRecorder()
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w1, newRequest("a"))
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
w2 := httptest.NewRecorder()
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w2, newRequest("b"))
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
// Both swaps must have reached Run() before either is marked ready —
|
|
||||||
// proves they ran in parallel rather than serializing.
|
|
||||||
<-a.runStarted
|
|
||||||
<-pb.runStarted
|
|
||||||
|
|
||||||
a.markReady()
|
|
||||||
pb.markReady()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done1:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("request A did not complete")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-done2:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("request B did not complete")
|
|
||||||
}
|
|
||||||
|
|
||||||
if w1.Code != http.StatusOK {
|
|
||||||
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
|
|
||||||
}
|
|
||||||
if w2.Code != http.StatusOK {
|
|
||||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
|
||||||
}
|
|
||||||
if got := a.stopCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
|
||||||
}
|
|
||||||
if got := pb.stopCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
|
|
||||||
// unblocks every queued request that no longer collides — they all start in
|
|
||||||
// parallel rather than one-per-completion.
|
|
||||||
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
pb := newFakeProcess("b")
|
|
||||||
pc := newFakeProcess("c")
|
|
||||||
|
|
||||||
// A's swap evicts both B and C, so B and C must queue. Once A finishes
|
|
||||||
// B and C themselves have empty evict sets, so they can start together.
|
|
||||||
planner := &stubPlanner{evict: map[string][]string{
|
|
||||||
"a": {"b", "c"},
|
|
||||||
}}
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
|
||||||
|
|
||||||
w1 := httptest.NewRecorder()
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w1, newRequest("a"))
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
<-a.runStarted
|
|
||||||
|
|
||||||
// B and C arrive while A is loading. evict_b and evict_c are empty,
|
|
||||||
// but collidesWith returns true because they appear in A's evict set.
|
|
||||||
w2 := httptest.NewRecorder()
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w2, newRequest("b"))
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
w3 := httptest.NewRecorder()
|
|
||||||
done3 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w3, newRequest("c"))
|
|
||||||
close(done3)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
if got := pb.runCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("b started early: runCalls=%d", got)
|
|
||||||
}
|
|
||||||
if got := pc.runCalls.Load(); got != 0 {
|
|
||||||
t.Errorf("c started early: runCalls=%d", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release A. The swapDone handler should drain the queue and start
|
|
||||||
// both B and C in parallel.
|
|
||||||
a.markReady()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
|
|
||||||
<-pb.runStarted
|
|
||||||
<-pc.runStarted
|
|
||||||
|
|
||||||
pb.markReady()
|
|
||||||
pc.markReady()
|
|
||||||
|
|
||||||
for i, ch := range []chan struct{}{done1, done2, done3} {
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatalf("request %d did not complete", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
|
|
||||||
// the shutdown error to every waiter on every active swap AND to every
|
|
||||||
// queued request.
|
|
||||||
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
pb := newFakeProcess("b")
|
|
||||||
pc := newFakeProcess("c")
|
|
||||||
|
|
||||||
// a and b load in parallel (empty evicts). c collides with both.
|
|
||||||
planner := &stubPlanner{evict: map[string][]string{
|
|
||||||
"c": {"a", "b"},
|
|
||||||
}}
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
|
||||||
|
|
||||||
const waitersPer = 2
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
codes := make([]int, 0, 2*waitersPer+1)
|
|
||||||
var codesMu sync.Mutex
|
|
||||||
record := func(code int) {
|
|
||||||
codesMu.Lock()
|
|
||||||
codes = append(codes, code)
|
|
||||||
codesMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
launch := func(model string) {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
b.ServeHTTP(w, newRequest(model))
|
|
||||||
record(w.Code)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Active swaps for a and b, each with 2 waiters.
|
|
||||||
for i := 0; i < waitersPer; i++ {
|
|
||||||
launch("a")
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
}
|
|
||||||
for i := 0; i < waitersPer; i++ {
|
|
||||||
launch("b")
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
}
|
|
||||||
// c collides with both → queues.
|
|
||||||
launch("c")
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
<-a.runStarted
|
|
||||||
<-pb.runStarted
|
|
||||||
|
|
||||||
if err := b.Shutdown(time.Second); err != nil {
|
|
||||||
t.Fatalf("Shutdown: %v", err)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
codesMu.Lock()
|
|
||||||
defer codesMu.Unlock()
|
|
||||||
if len(codes) != 2*waitersPer+1 {
|
|
||||||
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
|
|
||||||
}
|
|
||||||
for i, c := range codes {
|
|
||||||
if c == http.StatusOK {
|
|
||||||
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
|
|
||||||
// is not stopped to satisfy another model's swap while it is still handling
|
|
||||||
// a request.
|
|
||||||
//
|
|
||||||
// Sequence:
|
|
||||||
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
|
|
||||||
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
|
|
||||||
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
|
|
||||||
// intent collides with A.
|
|
||||||
// 4. r1 released — A finishes r1, then r3 is served by A.
|
|
||||||
// 5. B's swap then proceeds; r2 is served by B.
|
|
||||||
//
|
|
||||||
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
|
|
||||||
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
|
|
||||||
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
|
|
||||||
a := newFakeProcess("a")
|
|
||||||
// autoReady left false: we markReady manually after observing runStarted,
|
|
||||||
// so autoReady's setState(Ready) cannot race with a later Stop and leave
|
|
||||||
// A in Ready, masking the bug.
|
|
||||||
a.serveBlock = make(chan struct{})
|
|
||||||
pb := newFakeProcess("b")
|
|
||||||
// Same reasoning for B: park its swap on WaitReady until we choose.
|
|
||||||
|
|
||||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
|
||||||
|
|
||||||
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
|
|
||||||
w1 := httptest.NewRecorder()
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w1, newRequest("a"))
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
|
|
||||||
<-a.runStarted
|
|
||||||
a.markReady()
|
|
||||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A
|
|
||||||
<-a.serveStarted
|
|
||||||
|
|
||||||
// r2 — would evict A. A must not be stopped while r1 is in flight.
|
|
||||||
w2 := httptest.NewRecorder()
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w2, newRequest("b"))
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
// r3 — another request for A, arrives behind r2 and queues because
|
|
||||||
// B's swap intent (which evicts A) is recorded as active.
|
|
||||||
w3 := httptest.NewRecorder()
|
|
||||||
done3 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
b.ServeHTTP(w3, newRequest("a"))
|
|
||||||
close(done3)
|
|
||||||
}()
|
|
||||||
waitProcessed(t, b.testProcessed, 1)
|
|
||||||
|
|
||||||
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
|
|
||||||
// The router must hold off B's swap until A has drained.
|
|
||||||
close(a.serveBlock)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done1:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("r1 did not complete after serveBlock release")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for B.Run before marking it ready: markReady before Run would
|
|
||||||
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
|
|
||||||
// implementation B's swap only starts after A has drained; in the
|
|
||||||
// current implementation it has already started — either way runStarted
|
|
||||||
// fires.
|
|
||||||
<-pb.runStarted
|
|
||||||
pb.markReady()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done2:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("r2 did not complete after B marked ready")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-done3:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("r3 did not complete")
|
|
||||||
}
|
|
||||||
|
|
||||||
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
|
|
||||||
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
|
|
||||||
}
|
|
||||||
if w1.Body.String() != "ok:a" {
|
|
||||||
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
|
|
||||||
}
|
|
||||||
if w3.Body.String() != "ok:a" {
|
|
||||||
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
|
|
||||||
}
|
|
||||||
if w2.Body.String() != "ok:b" {
|
|
||||||
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
|
|
||||||
}
|
|
||||||
if a.stoppedWhileServing.Load() {
|
|
||||||
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
||||||
a := newFakeProcess("a")
|
a := newFakeProcess("a")
|
||||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|||||||
@@ -0,0 +1,404 @@
|
|||||||
|
# Router design
|
||||||
|
|
||||||
|
A developer tutorial for the `internal/router` package and its `scheduler`
|
||||||
|
sub-package.
|
||||||
|
|
||||||
|
## Intro
|
||||||
|
|
||||||
|
A llama-swap router is the component that sits behind the proxy and answers one
|
||||||
|
question for every incoming request: _can this model serve right now, and if
|
||||||
|
not, what has to happen first?_ Answering it means juggling three concerns that
|
||||||
|
used to live tangled together in one type:
|
||||||
|
|
||||||
|
1. **Process machinery** — owning the OS processes, starting and stopping them,
|
||||||
|
running health checks, and shuttling HTTP requests onto the right upstream.
|
||||||
|
2. **Scheduling strategy** — the queue, in-flight bookkeeping, and the decision
|
||||||
|
tree that turns one request into "serve now", "join an existing swap",
|
||||||
|
"queue", or "start a swap".
|
||||||
|
3. **Eviction policy** — given a model we want to load, which currently-running
|
||||||
|
models have to be stopped to make room?
|
||||||
|
|
||||||
|
The design pulls those three apart into separate, independently replaceable
|
||||||
|
pieces:
|
||||||
|
|
||||||
|
| Concern | Type | Lives in |
|
||||||
|
| ------------------- | ------------------------------ | ------------------------------- |
|
||||||
|
| Process machinery | `baseRouter` | `internal/router/base.go` |
|
||||||
|
| Scheduling strategy | `scheduler.Scheduler` (`FIFO`) | `internal/router/scheduler/` |
|
||||||
|
| Eviction policy | `scheduler.Swapper` | `groupSwapper`, `matrixSwapper` |
|
||||||
|
|
||||||
|
`baseRouter` keeps the channels, run loop, process lifecycle, and shutdown
|
||||||
|
teardown, and exposes the side-effects a scheduler needs through the
|
||||||
|
`scheduler.Effects` interface. The scheduler owns the queue and decision tree
|
||||||
|
but performs no side-effects directly — it calls back through `Effects`. The
|
||||||
|
`Swapper` is a pure function from "target model + currently running" to "models
|
||||||
|
to evict", and knows nothing about queues, channels, or processes.
|
||||||
|
|
||||||
|
Because the seams are interfaces, you can replace the scheduling strategy
|
||||||
|
without touching process management, or write a new eviction policy without
|
||||||
|
touching either. `FIFO` is the first and currently only `Scheduler`;
|
||||||
|
`groupSwapper` and `matrixSwapper` are the two `Swapper`s.
|
||||||
|
|
||||||
|
## Key concepts
|
||||||
|
|
||||||
|
### One run loop, no locks
|
||||||
|
|
||||||
|
`baseRouter.run()` is a single goroutine selecting over a handful of channels:
|
||||||
|
|
||||||
|
```go
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case req := <-b.shutdownCh: b.handleShutdown(req); return
|
||||||
|
case req := <-b.handlerCh: b.schedule.OnRequest(req)
|
||||||
|
case req := <-b.unloadCh: b.schedule.OnUnload(req.targets, req.timeout); close(req.respond)
|
||||||
|
case ev := <-b.swapDoneCh: b.schedule.OnSwapDone(ev)
|
||||||
|
case ev := <-b.serveDoneCh: b.schedule.OnServeDone(ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Every `Scheduler` method runs on this one goroutine. That is the single most
|
||||||
|
important fact about the design: **the scheduler never needs a mutex for its own
|
||||||
|
state**. All scheduler state is touched only from these callbacks, which are
|
||||||
|
serialized by the run loop. If you write a new scheduler, you get the same
|
||||||
|
guarantee for free — and you must not break it by spinning up goroutines that
|
||||||
|
mutate scheduler state.
|
||||||
|
|
||||||
|
### Events flow in, side-effects flow out
|
||||||
|
|
||||||
|
The run loop turns external happenings into method calls on the scheduler:
|
||||||
|
|
||||||
|
- A new HTTP request becomes `OnRequest(HandlerReq)`.
|
||||||
|
- A swap goroutine finishing becomes `OnSwapDone(SwapDone)`.
|
||||||
|
- A tracked request handler returning becomes `OnServeDone(ServeDoneEvent)`.
|
||||||
|
- An admin unload becomes `OnUnload(targets, timeout)`.
|
||||||
|
- Shutdown becomes `OnShutdown(err)`.
|
||||||
|
|
||||||
|
The scheduler reacts by calling **back out** through `Effects`: inspect a
|
||||||
|
process state, start a swap, grant a response to a caller, or stop processes. It
|
||||||
|
never calls `process.Process` directly and never writes to a channel directly.
|
||||||
|
This keeps the scheduler pure enough to unit-test against a fake `Effects` with
|
||||||
|
no goroutines or real processes involved (see `scheduler/fifo_test.go`).
|
||||||
|
|
||||||
|
```
|
||||||
|
HTTP request admin Unload / Shutdown
|
||||||
|
│ │
|
||||||
|
▼ ▼
|
||||||
|
ServeHTTP ──HandlerReq──▶ baseRouter.run() ◀──unloadCh/shutdownCh
|
||||||
|
│ (single goroutine)
|
||||||
|
▼
|
||||||
|
Scheduler.On*(...)
|
||||||
|
│ calls back through
|
||||||
|
▼
|
||||||
|
Effects: ModelState / RunningModels / StartSwap /
|
||||||
|
GrantServe / GrantError / StopProcesses
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
baseRouter side-effects: doSwap goroutine,
|
||||||
|
grant() to caller, process.Stop()
|
||||||
|
│
|
||||||
|
swap completes ──SwapDone──▶ back into run loop
|
||||||
|
```
|
||||||
|
|
||||||
|
### The swap goroutine
|
||||||
|
|
||||||
|
Scheduling decisions must be quick and non-blocking, but loading a model is
|
||||||
|
slow. The two are reconciled by doing the slow part on a separate goroutine.
|
||||||
|
|
||||||
|
When the scheduler decides to start a swap, inside `OnRequest` it:
|
||||||
|
|
||||||
|
1. records "a swap for X is in flight" in its own state, then
|
||||||
|
2. calls `Effects.StartSwap(modelID, evict)`.
|
||||||
|
|
||||||
|
`StartSwap` does **not** load the model itself — it just launches a detached
|
||||||
|
goroutine (`doSwap`) and returns straight away. `doSwap` is what does the slow
|
||||||
|
work: stop the evicted processes, start the target, wait for it to become ready.
|
||||||
|
Because `StartSwap` returned immediately, `OnRequest` returns too, and the run
|
||||||
|
loop is free to pick up the next event — another request, a serve-done, an
|
||||||
|
unload — while `doSwap` runs in the background.
|
||||||
|
|
||||||
|
The swap's eventual result comes back as just another event: when `doSwap`
|
||||||
|
finishes it posts a `SwapDone` onto `swapDoneCh`, which the run loop delivers as
|
||||||
|
`OnSwapDone`. So a slow load never blocks the run loop; it brackets it with two
|
||||||
|
quick events (`OnRequest` to start, `OnSwapDone` to finish) and everything in
|
||||||
|
between is handled normally.
|
||||||
|
|
||||||
|
### In-flight tracking and `trackedServe`
|
||||||
|
|
||||||
|
When the scheduler grants a request, the handler it hands back is wrapped by
|
||||||
|
`baseRouter.trackedServe`. The wrapper runs the real `ServeHTTP` and, on return,
|
||||||
|
posts a `ServeDoneEvent` so the run loop can decrement the per-model in-flight
|
||||||
|
count. This is why the scheduler can know whether a process is "busy": it counts
|
||||||
|
grants out and serve-dones in. A swap that would evict a busy process is
|
||||||
|
deferred until that process's in-flight count hits zero (`OnServeDone` then
|
||||||
|
re-drains the queue).
|
||||||
|
|
||||||
|
The subtle contract here is `GrantServe`'s boolean return. The caller's
|
||||||
|
`Respond` channel is unbuffered, so a successful send proves the HTTP goroutine
|
||||||
|
is alive and took the handler. If the caller already disconnected, the send
|
||||||
|
fails, `trackedServe` never runs, and **no** `ServeDoneEvent` will ever arrive —
|
||||||
|
so the scheduler must only increment `inFlight` when `GrantServe` returns true.
|
||||||
|
Incrementing on a false return would strand the counter above zero and the model
|
||||||
|
could never be evicted again.
|
||||||
|
|
||||||
|
## The interfaces
|
||||||
|
|
||||||
|
All three live in `scheduler/scheduler.go`.
|
||||||
|
|
||||||
|
### `Scheduler`
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Scheduler interface {
|
||||||
|
OnRequest(req HandlerReq)
|
||||||
|
OnSwapDone(ev SwapDone)
|
||||||
|
OnServeDone(ev ServeDoneEvent)
|
||||||
|
OnUnload(targets []string, timeout time.Duration)
|
||||||
|
OnShutdown(err error)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Owns the queue, in-flight tracking, and the decision tree. All methods run on
|
||||||
|
the run-loop goroutine, so no internal locking is needed.
|
||||||
|
|
||||||
|
### `Swapper`
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Swapper interface {
|
||||||
|
EvictionFor(target string, running []string) []string
|
||||||
|
OnSwapStart(target string, running []string)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The eviction policy. `EvictionFor` is a **pure decision** — given the target and
|
||||||
|
the complete `running` set, return the running model IDs that must stop. It must
|
||||||
|
not log or mutate anything, and it does **not** inspect process state itself:
|
||||||
|
the scheduler hands it `running` already assembled (every non-stopped process,
|
||||||
|
unioned with the targets of in-flight swaps already committed but not yet
|
||||||
|
visible in process state). That keeps the swapper a pure function of its inputs,
|
||||||
|
with no reference to processes.
|
||||||
|
|
||||||
|
The reason it must not log is that it is a _speculative_ query — "what would we
|
||||||
|
evict if we started this swap right now?" — called far more often than swaps
|
||||||
|
actually happen. The scheduler calls it once per incoming request, and then
|
||||||
|
**again for every still-queued request on every queue drain** (each `OnSwapDone`,
|
||||||
|
`OnServeDone`, and `OnUnload`). Most of those calls end in "still queued",
|
||||||
|
"collides", or "nothing to evict", not a real swap. Logging there would emit
|
||||||
|
duplicate lines for a request that simply sits in the queue, and lines for
|
||||||
|
decisions that never happen — the log would stop meaning "a swap occurred".
|
||||||
|
|
||||||
|
`OnSwapStart` is the one place a Swapper may log, because it is called exactly
|
||||||
|
once, at the moment a swap is committed. One log line there equals one real swap,
|
||||||
|
with the evict set that is genuinely being applied — which is why `matrixSwapper`
|
||||||
|
re-solves and logs the full decision (set, DSL, cost) in `OnSwapStart` rather
|
||||||
|
than in `EvictionFor`.
|
||||||
|
|
||||||
|
### `Effects`
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Effects interface {
|
||||||
|
ModelState(modelID string) (process.ProcessState, bool)
|
||||||
|
RunningModels() map[string]process.ProcessState
|
||||||
|
StartSwap(modelID string, evict []string)
|
||||||
|
GrantError(req HandlerReq, err error)
|
||||||
|
GrantServe(req HandlerReq, modelID string) bool
|
||||||
|
StopProcesses(timeout time.Duration, ids []string)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Implemented by `baseRouter`. This is the scheduler's entire window onto the
|
||||||
|
outside world; everything else about the router is hidden from it. See the
|
||||||
|
deep-dive below.
|
||||||
|
|
||||||
|
### `Factory` — wiring it together
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler
|
||||||
|
```
|
||||||
|
|
||||||
|
`baseRouter` doesn't know which scheduler or swapper it has — it is handed a
|
||||||
|
`Factory` at construction and calls it once, passing itself as the `Effects`.
|
||||||
|
The concrete router captures its `Swapper` in the closure. From `group.go`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
swapper := &groupSwapper{ /* ... */ }
|
||||||
|
base := newBaseRouter("group", conf, processes, proxylog,
|
||||||
|
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||||
|
return scheduler.NewFIFO(name, logger, swapper, eff)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
This closure is the single point where the three pieces meet: it binds a
|
||||||
|
specific `Swapper` (`swapper`) and a specific `Scheduler` (`FIFO`) to the
|
||||||
|
`baseRouter`'s `Effects` (`eff`).
|
||||||
|
|
||||||
|
**The swapper is a separate type from the concrete router.** There are currently two router implementations: `router.Group` and `router.Matrix`. Each has its own swapper that implements `scheduler.Swapper` with that router's eviction logic. This decoupling of responsibilities makes it easy to implement custom swapping strategies.
|
||||||
|
|
||||||
|
### The events
|
||||||
|
|
||||||
|
A single goroutine in `baseRouter.run()` owns and serializes all state changes in the router. By processing events one at a time, it keeps state changes free of data races and eliminates the need for complex mutex logic.
|
||||||
|
|
||||||
|
These are the events the router currently uses:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type HandlerReq struct { // one in-flight ServeHTTP awaiting a decision
|
||||||
|
Model string
|
||||||
|
Ctx context.Context
|
||||||
|
Respond chan HandlerResp // UNBUFFERED — see GrantServe contract
|
||||||
|
PositionCh chan int // queue-position updates for the loading UI
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandlerResp struct { // the decision handed back to the caller
|
||||||
|
HandleFunc http.HandlerFunc // serve with this, or...
|
||||||
|
Err error // ...fail with this
|
||||||
|
}
|
||||||
|
|
||||||
|
type SwapDone struct{ ModelID string; Err error } // swap goroutine finished
|
||||||
|
type ServeDoneEvent struct{ ModelID string } // tracked handler returned
|
||||||
|
```
|
||||||
|
|
||||||
|
## Deep-dive: the `Effects` interface and why it exists
|
||||||
|
|
||||||
|
`Effects` is the inversion-of-control boundary that makes the split possible.
|
||||||
|
The scheduler decides and `baseRouter` _acts_. Pulling the side-effects behind this
|
||||||
|
interface buys three things:
|
||||||
|
|
||||||
|
1. **Purity and testability.** The scheduler performs no I/O, starts no
|
||||||
|
goroutines of its own, and touches no real processes. Its tests drive the
|
||||||
|
`On*` methods directly and assert on a `fakeEffects` that just records the
|
||||||
|
calls — synchronous, deterministic, no sleeps. (`scheduler/fifo_test.go`.)
|
||||||
|
2. **A single, auditable side-effect surface.** Every externally-visible thing a
|
||||||
|
scheduler can do is one of six methods. You can reason about the whole
|
||||||
|
contract by reading one interface.
|
||||||
|
3. **Decoupling lifetime.** The scheduler never holds a `process.Process`,
|
||||||
|
never sees a channel, and never learns how shutdown teardown works. It only
|
||||||
|
knows model IDs and states.
|
||||||
|
|
||||||
|
Method by method, as implemented in `base.go`:
|
||||||
|
|
||||||
|
- **`ModelState(modelID) (state, ok)`** — read-only snapshot of a process's
|
||||||
|
state, and whether this router handles the model at all. The scheduler uses it
|
||||||
|
for the "unknown model" check and the "already ready" fast path. Safe to call
|
||||||
|
any time because the process map is fixed at construction and `State()` is a
|
||||||
|
snapshot.
|
||||||
|
|
||||||
|
- **`RunningModels()`** — the state of every process that isn't stopped or shut
|
||||||
|
down. The scheduler unions its keys with its own in-flight swap targets to
|
||||||
|
build the `running` set it hands the `Swapper`, so the swapper never has to
|
||||||
|
touch process state itself.
|
||||||
|
|
||||||
|
- **`StartSwap(modelID, evict)`** — fire-and-forget. `baseRouter` launches the
|
||||||
|
`doSwap` goroutine and returns immediately; the result comes back later as a
|
||||||
|
`SwapDone`. The scheduler records the swap as active _before_ calling this so
|
||||||
|
that requests arriving in the meantime can join it.
|
||||||
|
|
||||||
|
- **`GrantError(req, err)`** — hand a caller an error response. Used for unknown
|
||||||
|
models, failed swaps, unloads, and shutdown.
|
||||||
|
|
||||||
|
- **`GrantServe(req, modelID) bool`** — hand a caller the tracked handler for a
|
||||||
|
ready model, returning whether the caller was still there to receive it. The
|
||||||
|
scheduler increments the in-flight count **only on a true return** (see the
|
||||||
|
in-flight contract above). This is the one `Effects` method whose return value
|
||||||
|
carries state-machine significance.
|
||||||
|
|
||||||
|
- **`StopProcesses(timeout, ids)`** — stop processes in parallel and **block**
|
||||||
|
until all have stopped. Used by `OnUnload` so an admin `Unload` call can
|
||||||
|
guarantee the process is dead by the time it returns. (Note `StartSwap` is
|
||||||
|
async but `StopProcesses` is sync — the difference is deliberate and tied to
|
||||||
|
the caller's expectations.)
|
||||||
|
|
||||||
|
A useful way to hold it in your head: `Effects` is the scheduler's syscall
|
||||||
|
table. The scheduler is a pure state machine; `Effects` is how it touches the
|
||||||
|
world, and `baseRouter` is the kernel that implements those syscalls with real
|
||||||
|
goroutines, channels, and processes.
|
||||||
|
|
||||||
|
## How to implement a new `Swapper`
|
||||||
|
|
||||||
|
A `Swapper` is a pure decision function plus a logging hook — the easiest of the three pieces to replace.
|
||||||
|
|
||||||
|
1. **Write the swapper type** and give it whatever config it needs to make a
|
||||||
|
decision. It does **not** need the process map — the scheduler supplies the
|
||||||
|
running set as an argument. `groupSwapper` holds only its group config;
|
||||||
|
`matrixSwapper` holds only its solver and logger:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type mySwapper struct {
|
||||||
|
config config.Config
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Implement `EvictionFor(target, running)`** as a _pure_ decision:
|
||||||
|
- `running` is the complete live set, already assembled for you: every
|
||||||
|
non-stopped process unioned with the targets of in-flight swaps the
|
||||||
|
scheduler has committed to. You don't filter process state or fold in
|
||||||
|
in-flight targets yourself — that's the scheduler's job. Just decide against the slice you're handed.
|
||||||
|
- Return the list of model IDs in `running` that must stop for `target` to
|
||||||
|
run. Return `nil`/empty when nothing needs evicting.
|
||||||
|
- Do **not** mutate state here.
|
||||||
|
- Do **not** log here. It can be called multiple times per request. Since it is a pure function, have tests verify the expected behaviour.
|
||||||
|
|
||||||
|
3. **Implement `OnSwapStart(target, running)`** — called once when a swap
|
||||||
|
actually begins, with the same `running` set `EvictionFor` saw. This is the
|
||||||
|
right place to log: one call equals one real swap. `matrixSwapper` re-solves
|
||||||
|
and logs the chosen set and cost here; `groupSwapper` logs nothing.
|
||||||
|
|
||||||
|
4. **Wire it in** by instantiating the swapper in your router's constructor and
|
||||||
|
capturing it in the `Factory` closure passed to `newBaseRouter` — exactly as
|
||||||
|
`NewGroup` and `NewMatrix` do. The router struct itself only ever embeds
|
||||||
|
`*baseRouter`; the swapper reaches the scheduler solely through that closure.
|
||||||
|
|
||||||
|
Reference implementations: `groupSwapper` (static group config) in `group.go`
|
||||||
|
and `matrixSwapper` (cost-based set solver) in `matrix.go`.
|
||||||
|
|
||||||
|
## How to implement a new `Scheduler`
|
||||||
|
|
||||||
|
Replacing the scheduler means taking over the queue and the entire decision tree. Read `scheduler/fifo.go` end to end first — it is the reference implementation and the rules below are easiest to understand in context.
|
||||||
|
|
||||||
|
The rules you must honour:
|
||||||
|
|
||||||
|
- **Single goroutine.** Every method runs on the `baseRouter.run()` goroutine. Keep your state in plain maps/slices and never read or write it from another goroutine. If you need slow work done, hand it to `Effects.StartSwap` and react to the resulting `SwapDone` — do not block a method waiting for it.
|
||||||
|
|
||||||
|
- **Never block the run loop.** `OnRequest`, `OnSwapDone`, and `OnServeDone` must make a decision and return. The one method allowed to block is `OnUnload`, and only because it must wait on the synchronous `StopProcesses` so the admin caller's guarantee holds.
|
||||||
|
|
||||||
|
- **Respect the `GrantServe` boolean.** Only count a request as in-flight when `GrantServe` returns true (see the in-flight contract above). A false return means the caller is gone; no `ServeDoneEvent` will ever arrive, so incrementing on false permanently strands the counter.
|
||||||
|
|
||||||
|
- **Account for in-flight swaps in your running set.** When you call `Swapper.EvictionFor`, the running set you pass must include not just live processes (`Effects.RunningModels`) but also the targets of swaps you've already started that aren't yet visible in process state — otherwise the swapper contradicts decisions already in motion.
|
||||||
|
|
||||||
|
What each method must do:
|
||||||
|
|
||||||
|
- **`OnRequest(req)`** — every request must resolve to exactly one of: granted, errored, joined (piggybacks an in-flight swap), queued, or swap-started. No request may be silently dropped.
|
||||||
|
|
||||||
|
- **`OnSwapDone(ev)`** — deliver the result to every waiter that joined this swap (grant on success, error on `ev.Err`), drop the swap from active tracking, then re-examine anything queued — a finished swap may have unblocked it.
|
||||||
|
|
||||||
|
- **`OnServeDone(ev)`** — decrement the model's in-flight count; when it hits zero, re-examine the queue. Do **not** clear in-flight counts by hand; the handlers post their own `ServeDoneEvent`s on return.
|
||||||
|
|
||||||
|
- **`OnUnload(targets, timeout)`** — error out any waiters or queued requests for the unloaded models, call `Effects.StopProcesses` (synchronously — the admin caller relies on the process being dead afterwards), then re-examine the queue.
|
||||||
|
|
||||||
|
- **`OnShutdown(err)`** — error out every waiter you still hold (active swap waiters and queued requests). Don't touch processes; teardown is `baseRouter`'s job.
|
||||||
|
|
||||||
|
Expose a constructor matching the `Factory` shape:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func NewMyScheduler(name string, logger *logmon.Monitor, swapper Swapper, eff Effects) *MyScheduler {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
// in the concrete router:
|
||||||
|
base := newBaseRouter(name, conf, processes, proxylog,
|
||||||
|
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||||
|
return scheduler.NewMyScheduler(name, logger, swapper, eff)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- **Schedulers** are tested as pure state machines in the `scheduler` package:
|
||||||
|
drive the `On*` methods directly against a `fakeEffects` and assert on the
|
||||||
|
recorded grants/starts/stops. No goroutines, no sleeps. See
|
||||||
|
`scheduler/fifo_test.go` as the reference; follow the `TestSchedulerName_<scenario>`
|
||||||
|
naming convention.
|
||||||
|
- **`baseRouter` mechanism** (run loop, `grant`/`ServeHTTP`, `Unload`,
|
||||||
|
`Shutdown`) is tested in `base_test.go`. The run loop exposes a
|
||||||
|
`testProcessed` channel so tests can wait for an event to be fully processed
|
||||||
|
instead of sleeping.
|
||||||
|
- Run new tests with `go test -v -run TestMyScheduler_... ./internal/router/scheduler/`,
|
||||||
|
then `make test-dev` for a quick `go test` + `staticcheck` pass over `proxy/`.
|
||||||
+13
-19
@@ -14,7 +14,7 @@ type Group struct {
|
|||||||
|
|
||||||
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
||||||
modelToGroup := make(map[string]string)
|
modelToGroup := make(map[string]string)
|
||||||
for gid, gcfg := range conf.Groups {
|
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||||
for _, mid := range gcfg.Members {
|
for _, mid := range gcfg.Members {
|
||||||
if existing, dup := modelToGroup[mid]; dup {
|
if existing, dup := modelToGroup[mid]; dup {
|
||||||
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
||||||
@@ -23,14 +23,16 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
planner := &groupPlanner{
|
swapper := &groupSwapper{
|
||||||
config: conf,
|
config: conf,
|
||||||
modelToGroup: modelToGroup,
|
modelToGroup: modelToGroup,
|
||||||
}
|
}
|
||||||
|
|
||||||
processes := make(map[string]process.Process, len(modelToGroup))
|
processes := make(map[string]process.Process, len(modelToGroup))
|
||||||
base := newBaseRouter("group", conf, processes, planner, proxylog)
|
base, err := newBaseRouter("group", conf, processes, proxylog, swapper)
|
||||||
planner.processes = processes
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating base router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for mid := range modelToGroup {
|
for mid := range modelToGroup {
|
||||||
modelCfg, _, ok := conf.FindConfig(mid)
|
modelCfg, _, ok := conf.FindConfig(mid)
|
||||||
@@ -54,21 +56,20 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
|
|||||||
return g, nil
|
return g, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// groupPlanner decides evictions from static group configuration.
|
// groupSwapper decides evictions from static group configuration.
|
||||||
//
|
//
|
||||||
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
||||||
// members are stopped only when the target's group is exclusive; loading a
|
// members are stopped only when the target's group is exclusive; loading a
|
||||||
// model from a non-exclusive group leaves running exclusive groups alone,
|
// model from a non-exclusive group leaves running exclusive groups alone,
|
||||||
// matching the gotcha in the original ProcessGroup behaviour.
|
// matching the gotcha in the original ProcessGroup behaviour.
|
||||||
type groupPlanner struct {
|
type groupSwapper struct {
|
||||||
config config.Config
|
config config.Config
|
||||||
modelToGroup map[string]string
|
modelToGroup map[string]string
|
||||||
processes map[string]process.Process
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
func (p *groupSwapper) EvictionFor(target string, running []string) []string {
|
||||||
tg := p.modelToGroup[target]
|
tg := p.modelToGroup[target]
|
||||||
tgCfg := p.config.Groups[tg]
|
tgCfg := p.config.Routing.Router.Settings.Groups[tg]
|
||||||
|
|
||||||
seen := make(map[string]struct{})
|
seen := make(map[string]struct{})
|
||||||
var result []string
|
var result []string
|
||||||
@@ -89,24 +90,17 @@ func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string
|
|||||||
// for backwards compatibility. The newer swap matrix approach does not
|
// for backwards compatibility. The newer swap matrix approach does not
|
||||||
// have this issue.
|
// have this issue.
|
||||||
case og != tg && tgCfg.Exclusive:
|
case og != tg && tgCfg.Exclusive:
|
||||||
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
|
if ogCfg := p.config.Routing.Router.Settings.Groups[og]; !ogCfg.Persistent {
|
||||||
seen[mID] = struct{}{}
|
seen[mID] = struct{}{}
|
||||||
result = append(result, mID)
|
result = append(result, mID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for mID, proc := range p.processes {
|
for _, mID := range running {
|
||||||
st := proc.State()
|
|
||||||
if st == process.StateStopped || st == process.StateShutdown {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
consider(mID)
|
|
||||||
}
|
|
||||||
for _, mID := range alsoRunning {
|
|
||||||
consider(mID)
|
consider(mID)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *groupPlanner) OnSwapStart(target string) {}
|
func (p *groupSwapper) OnSwapStart(target string, running []string) {}
|
||||||
|
|||||||
@@ -17,17 +17,19 @@ import (
|
|||||||
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
modelToGroup := make(map[string]string)
|
modelToGroup := make(map[string]string)
|
||||||
for gid, gcfg := range conf.Groups {
|
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||||
for _, mid := range gcfg.Members {
|
for _, mid := range gcfg.Members {
|
||||||
modelToGroup[mid] = gid
|
modelToGroup[mid] = gid
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
planner := &groupPlanner{
|
swapper := &groupSwapper{
|
||||||
config: conf,
|
config: conf,
|
||||||
modelToGroup: modelToGroup,
|
modelToGroup: modelToGroup,
|
||||||
processes: processes,
|
|
||||||
}
|
}
|
||||||
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
|
base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newBaseRouter: %v", err)
|
||||||
|
}
|
||||||
base.testProcessed = make(chan struct{}, 64)
|
base.testProcessed = make(chan struct{}, 64)
|
||||||
g := &Group{baseRouter: base}
|
g := &Group{baseRouter: base}
|
||||||
go base.run()
|
go base.run()
|
||||||
@@ -41,10 +43,10 @@ func newTestGroup(t *testing.T, conf config.Config, processes map[string]process
|
|||||||
|
|
||||||
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g1": {Swap: true, Members: []string{"a"}},
|
"g1": {Swap: true, Members: []string{"a"}},
|
||||||
"g2": {Swap: true, Members: []string{"a"}},
|
"g2": {Swap: true, Members: []string{"a"}},
|
||||||
},
|
}),
|
||||||
Models: map[string]config.ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"a": {},
|
"a": {},
|
||||||
},
|
},
|
||||||
@@ -65,9 +67,9 @@ func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
|
|||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
@@ -97,9 +99,9 @@ func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
|
|||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
@@ -127,10 +129,10 @@ func TestGroup_CrossGroupExclusive(t *testing.T) {
|
|||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||||
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
@@ -154,10 +156,10 @@ func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
|||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
||||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||||
|
|
||||||
@@ -202,16 +204,17 @@ func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
|||||||
|
|
||||||
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
||||||
// (Swap=true) serialise even when both arrive while neither has reached
|
// (Swap=true) serialise even when both arrive while neither has reached
|
||||||
// StateStarting yet — the alsoRunning hint to the planner closes that race.
|
// StateStarting yet — the in-flight swap target the scheduler folds into the
|
||||||
|
// running set closes that race.
|
||||||
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||||
a := newFakeProcess("a")
|
a := newFakeProcess("a")
|
||||||
pb := newFakeProcess("b")
|
pb := newFakeProcess("b")
|
||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||||
|
|
||||||
@@ -224,8 +227,9 @@ func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
|||||||
waitProcessed(t, g.testProcessed, 1)
|
waitProcessed(t, g.testProcessed, 1)
|
||||||
|
|
||||||
// Request B arrives before A transitions to StateStarting in the process
|
// Request B arrives before A transitions to StateStarting in the process
|
||||||
// state machine. Without the alsoRunning hint, the planner would not see
|
// state machine. Without folding the in-flight swap target into the running
|
||||||
// A as running, and B would start in parallel, violating Swap=true.
|
// set, the swapper would not see A as running, and B would start in
|
||||||
|
// parallel, violating Swap=true.
|
||||||
w2 := httptest.NewRecorder()
|
w2 := httptest.NewRecorder()
|
||||||
done2 := make(chan struct{})
|
done2 := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
@@ -269,10 +273,10 @@ func TestGroup_PersistentNotEvicted(t *testing.T) {
|
|||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
||||||
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
@@ -306,10 +310,10 @@ func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
|
|||||||
|
|
||||||
conf := config.Config{
|
conf := config.Config{
|
||||||
HealthCheckTimeout: 5,
|
HealthCheckTimeout: 5,
|
||||||
Groups: map[string]config.GroupConfig{
|
Routing: groupRouting(map[string]config.GroupConfig{
|
||||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||||
},
|
}),
|
||||||
}
|
}
|
||||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
|||||||
@@ -12,10 +12,23 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// groupRouting builds a normalized RoutingConfig for the group router, mirroring
|
||||||
|
// what config.LoadConfigFromReader produces. Tests use it to populate
|
||||||
|
// config.Config.Routing without going through LoadConfig.
|
||||||
|
func groupRouting(groups map[string]config.GroupConfig) config.RoutingConfig {
|
||||||
|
return config.RoutingConfig{
|
||||||
|
Router: config.RouterConfig{
|
||||||
|
Use: "group",
|
||||||
|
Settings: config.RouterSettings{Groups: groups},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// fakeProcess is an in-memory implementation of process.Process used to drive
|
// fakeProcess is an in-memory implementation of process.Process used to drive
|
||||||
// the routers through their state machine without spawning real upstreams.
|
// the routers through their state machine without spawning real upstreams.
|
||||||
type fakeProcess struct {
|
type fakeProcess struct {
|
||||||
|
|||||||
@@ -226,69 +226,6 @@ func TestIsLoadingPath(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
wantStreaming bool
|
|
||||||
}{
|
|
||||||
{"streaming true", "model=llama3&stream=true", true},
|
|
||||||
{"streaming false", "model=llama3&stream=false", false},
|
|
||||||
{"no stream param", "model=llama3", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if got.Streaming != tt.wantStreaming {
|
|
||||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
wantStreaming bool
|
|
||||||
}{
|
|
||||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
|
||||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
|
||||||
{"no stream param", `{"model":"llama3"}`, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
|
||||||
r.Header.Set("Content-Type", "application/json")
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if got.Streaming != tt.wantStreaming {
|
|
||||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if !got.Streaming {
|
|
||||||
t.Error("Streaming should be true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func countSSEMessages(s string) int {
|
func countSSEMessages(s string) int {
|
||||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||||
count := 0
|
count := 0
|
||||||
|
|||||||
+16
-45
@@ -2,7 +2,6 @@ package router
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
@@ -14,20 +13,23 @@ type Matrix struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
||||||
if conf.Matrix == nil {
|
mtx := conf.Routing.Router.Settings.Matrix
|
||||||
|
if mtx == nil {
|
||||||
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
planner := &matrixPlanner{
|
swapper := &matrixSwapper{
|
||||||
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
|
solver: newMatrixSolver(mtx.ExpandedSets, mtx.ResolvedEvictCosts()),
|
||||||
logger: proxylog,
|
logger: proxylog,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build a process for every model in the config. Any model can run alone
|
// Build a process for every model in the config. Any model can run alone
|
||||||
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
||||||
processes := make(map[string]process.Process, len(conf.Models))
|
processes := make(map[string]process.Process, len(conf.Models))
|
||||||
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
|
base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper)
|
||||||
planner.processes = processes
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating base router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for mid, modelCfg := range conf.Models {
|
for mid, modelCfg := range conf.Models {
|
||||||
procLog := logmon.NewWriter(upstreamlog)
|
procLog := logmon.NewWriter(upstreamlog)
|
||||||
@@ -45,20 +47,18 @@ func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matr
|
|||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matrixPlanner decides evictions by asking the matrix solver against the
|
// matrixSwapper decides evictions by asking the matrix solver against the
|
||||||
// current running set.
|
// running set the scheduler hands it.
|
||||||
type matrixPlanner struct {
|
type matrixSwapper struct {
|
||||||
solver *matrixSolver
|
solver *matrixSolver
|
||||||
processes map[string]process.Process
|
logger *logmon.Monitor
|
||||||
logger *logmon.Monitor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
func (p *matrixSwapper) EvictionFor(target string, running []string) []string {
|
||||||
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
|
return p.solver.Solve(target, running).Evict
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *matrixPlanner) OnSwapStart(target string) {
|
func (p *matrixSwapper) OnSwapStart(target string, running []string) {
|
||||||
running := p.runningModels()
|
|
||||||
result := p.solver.Solve(target, running)
|
result := p.solver.Solve(target, running)
|
||||||
switch {
|
switch {
|
||||||
case len(result.Evict) > 0:
|
case len(result.Evict) > 0:
|
||||||
@@ -70,32 +70,3 @@ func (p *matrixPlanner) OnSwapStart(target string) {
|
|||||||
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *matrixPlanner) runningModels() []string {
|
|
||||||
return p.runningSet(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// runningSet returns the union of live processes (State != Stopped/Shutdown)
|
|
||||||
// and any extra IDs the baseRouter has already committed to loading but which
|
|
||||||
// the process state machine has not yet reflected.
|
|
||||||
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
|
|
||||||
seen := make(map[string]struct{}, len(p.processes))
|
|
||||||
var running []string
|
|
||||||
for id, proc := range p.processes {
|
|
||||||
st := proc.State()
|
|
||||||
if st == process.StateStopped || st == process.StateShutdown {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[id] = struct{}{}
|
|
||||||
running = append(running, id)
|
|
||||||
}
|
|
||||||
for _, id := range alsoRunning {
|
|
||||||
if _, dup := seen[id]; dup {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[id] = struct{}{}
|
|
||||||
running = append(running, id)
|
|
||||||
}
|
|
||||||
sort.Strings(running)
|
|
||||||
return running
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ import (
|
|||||||
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
logger := logmon.NewWriter(io.Discard)
|
logger := logmon.NewWriter(io.Discard)
|
||||||
planner := &matrixPlanner{
|
swapper := &matrixSwapper{
|
||||||
solver: newMatrixSolver(expanded, evictCosts),
|
solver: newMatrixSolver(expanded, evictCosts),
|
||||||
processes: processes,
|
logger: logger,
|
||||||
logger: logger,
|
}
|
||||||
|
base, err := newBaseRouter("matrix", conf, processes, logger, swapper)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newBaseRouter: %v", err)
|
||||||
}
|
}
|
||||||
base := newBaseRouter("matrix", conf, processes, planner, logger)
|
|
||||||
base.testProcessed = make(chan struct{}, 64)
|
base.testProcessed = make(chan struct{}, 64)
|
||||||
r := &Matrix{baseRouter: base}
|
r := &Matrix{baseRouter: base}
|
||||||
go base.run()
|
go base.run()
|
||||||
@@ -153,8 +155,8 @@ func TestMatrix_CoexistingSetParallel(t *testing.T) {
|
|||||||
|
|
||||||
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
||||||
// that cannot coexist with the in-flight first model queues until the first
|
// that cannot coexist with the in-flight first model queues until the first
|
||||||
// completes, and then evicts it. This exercises the alsoRunning hint via the
|
// completes, and then evicts it. This exercises the scheduler folding in-flight
|
||||||
// matrix solver's union into runningSet.
|
// swap targets into the running set it hands the swapper.
|
||||||
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||||
a := newFakeProcess("a")
|
a := newFakeProcess("a")
|
||||||
pb := newFakeProcess("b")
|
pb := newFakeProcess("b")
|
||||||
@@ -173,8 +175,9 @@ func TestMatrix_IncompatibleQueues(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
waitProcessed(t, r.testProcessed, 1)
|
waitProcessed(t, r.testProcessed, 1)
|
||||||
|
|
||||||
// B arrives before A transitions to StateStarting. The solver sees A via
|
// B arrives before A transitions to StateStarting. The running set the
|
||||||
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
|
// scheduler builds includes A (an in-flight swap target), so the solver
|
||||||
|
// returns evict=[a] and collidesWith forces B to queue.
|
||||||
w2 := httptest.NewRecorder()
|
w2 := httptest.NewRecorder()
|
||||||
done2 := make(chan struct{})
|
done2 := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
type peerMember struct {
|
type peerMember struct {
|
||||||
@@ -146,22 +147,22 @@ func (r *Peer) Shutdown(timeout time.Duration) error {
|
|||||||
|
|
||||||
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
if r.shuttingDown.Load() {
|
if r.shuttingDown.Load() {
|
||||||
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.inflight.Add(1)
|
r.inflight.Add(1)
|
||||||
defer r.inflight.Done()
|
defer r.inflight.Done()
|
||||||
|
|
||||||
data, err := FetchContext(req, r.cfg)
|
data, err := shared.FetchContext(req, r.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SendError(w, req, err)
|
shared.SendError(w, req, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pp, found := r.peers[data.ModelID]
|
pp, found := r.peers[data.ModelID]
|
||||||
if !found {
|
if !found {
|
||||||
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||||
SendError(w, req, ErrNoPeerModelFound)
|
shared.SendError(w, req, ErrNoPeerModelFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testLogger = logmon.NewWriter(os.Stdout)
|
var testLogger = logmon.NewWriter(os.Stdout)
|
||||||
@@ -142,7 +143,7 @@ func TestPeer_ServeHTTP_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -178,7 +179,7 @@ func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -212,7 +213,7 @@ func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -246,7 +247,7 @@ func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -279,7 +280,7 @@ func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -311,7 +312,7 @@ func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -347,7 +348,7 @@ func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -385,7 +386,7 @@ func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -448,7 +449,7 @@ func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -551,7 +552,7 @@ func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
|||||||
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
|
|||||||
+4
-151
@@ -1,39 +1,18 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// contextkey is the unexported key type for values this package stores in a
// context.Context; name is a human-readable label for the key.
type contextkey struct{ name string }
|
|
||||||
|
|
||||||
// ReqContextData is the per-request routing information carried in the
// request context.
type ReqContextData struct {
	// Model is the model name as the client sent it; ModelID is the name it
	// resolved to (FetchContext falls back to Model when no mapping exists).
	Model, ModelID string
	// Streaming mirrors the request's "stream" flag; SendLoadingState is
	// copied from the resolved model's configuration.
	Streaming, SendLoadingState bool
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
ErrNoRouterFound = shared.ErrNoRouterFound
|
||||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
ErrNoPeerModelFound = shared.ErrNoPeerModelFound
|
||||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
ErrNoLocalModelFound = shared.ErrNoLocalModelFound
|
||||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
|
||||||
|
|
||||||
ContextKey = &contextkey{"context"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
@@ -71,129 +50,3 @@ type LocalRouter interface {
|
|||||||
// model is not known to this router.
|
// model is not known to this router.
|
||||||
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchContext will attempt to get the model id from the context then
|
|
||||||
// from the model body. If it extracts the model from the body it will
|
|
||||||
// store the model in the context for downstream handlers. An error
|
|
||||||
// will be returned when model can not be fetch from either location.
|
|
||||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
|
||||||
data, ok := ReadContext(r.Context())
|
|
||||||
if ok {
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if data, err := ExtractContext(r); err == nil {
|
|
||||||
realName, _ := cfg.RealModelName(data.Model)
|
|
||||||
if realName == "" {
|
|
||||||
realName = data.Model
|
|
||||||
}
|
|
||||||
data.ModelID = realName
|
|
||||||
if mc, ok := cfg.Models[realName]; ok {
|
|
||||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
|
||||||
}
|
|
||||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ReqContextData{}, ErrNoModelInContext
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
|
||||||
return context.WithValue(ctx, ContextKey, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
|
||||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
|
||||||
return data, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractContext pulls the model name from an HTTP request without consuming the
|
|
||||||
// body. For GET requests it reads the "model" query parameter. For POST
|
|
||||||
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
|
|
||||||
// application/x-www-form-urlencoded bodies. The request body is always restored
|
|
||||||
// before returning so downstream handlers — including reverse proxies that
|
|
||||||
// forward raw bytes upstream — can still read it.
|
|
||||||
func ExtractContext(r *http.Request) (ReqContextData, error) {
|
|
||||||
if r.Method == http.MethodGet {
|
|
||||||
if model := r.URL.Query().Get("model"); model != "" {
|
|
||||||
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
|
|
||||||
}
|
|
||||||
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
}()
|
|
||||||
|
|
||||||
contentType := r.Header.Get("Content-Type")
|
|
||||||
|
|
||||||
if strings.Contains(contentType, "application/json") {
|
|
||||||
model := gjson.GetBytes(bodyBytes, "model").String()
|
|
||||||
if model == "" {
|
|
||||||
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
|
|
||||||
}
|
|
||||||
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
|
||||||
// buffered bytes. The deferred restore above will reset r.Body again
|
|
||||||
// after parsing.
|
|
||||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
if strings.Contains(contentType, "multipart/form-data") {
|
|
||||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
|
||||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if model := r.FormValue("model"); model != "" {
|
|
||||||
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
|
|
||||||
}
|
|
||||||
|
|
||||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
switch {
|
|
||||||
case errors.Is(err, ErrNoModelInContext):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
|
||||||
case errors.Is(err, ErrNoPeerModelFound):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
|
||||||
case errors.Is(err, ErrNoLocalModelFound):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
|
||||||
case errors.Is(err, ErrNoRouterFound):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
|
||||||
default:
|
|
||||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
|
||||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
|
||||||
// Check Accept header for preferred response format
|
|
||||||
acceptHeader := r.Header.Get("Accept")
|
|
||||||
if strings.Contains(acceptHeader, "text/plain") {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(acceptHeader, "text/html") {
|
|
||||||
w.Header().Set("Content-Type", "text/html")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,275 +0,0 @@
|
|||||||
package router
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestExtractContext_GET(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
wantModel string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"model present", "model=llama3", "llama3", false},
|
|
||||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
|
||||||
{"model missing", "", "", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
|
||||||
}
|
|
||||||
if got.Model != tt.wantModel {
|
|
||||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_JSON(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
wantModel string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
|
||||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
|
||||||
{"model empty string", `{"model":""}`, "", true},
|
|
||||||
{"model key missing", `{"stream":true}`, "", true},
|
|
||||||
{"invalid json", `not-json`, "", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
|
||||||
r.Header.Set("Content-Type", "application/json")
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
|
||||||
}
|
|
||||||
if got.Model != tt.wantModel {
|
|
||||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
formModel string
|
|
||||||
wantModel string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"model present", "whisper-1", "whisper-1", false},
|
|
||||||
{"model missing", "", "", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
form := url.Values{}
|
|
||||||
if tt.formModel != "" {
|
|
||||||
form.Set("model", tt.formModel)
|
|
||||||
}
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
|
||||||
}
|
|
||||||
if got.Model != tt.wantModel {
|
|
||||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_MultipartForm(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
formModel string
|
|
||||||
wantModel string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"model present", "whisper-1", "whisper-1", false},
|
|
||||||
{"model missing", "", "", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
mw := multipart.NewWriter(&buf)
|
|
||||||
if tt.formModel != "" {
|
|
||||||
fw, _ := mw.CreateFormField("model")
|
|
||||||
fw.Write([]byte(tt.formModel))
|
|
||||||
}
|
|
||||||
mw.Close()
|
|
||||||
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
|
||||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
|
||||||
}
|
|
||||||
if got.Model != tt.wantModel {
|
|
||||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
|
||||||
body := `{"model":"llama3","stream":true}`
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
|
||||||
r.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
if _, err := ExtractContext(r); err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
remaining, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if string(remaining) != body {
|
|
||||||
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
mw := multipart.NewWriter(&buf)
|
|
||||||
fw, _ := mw.CreateFormField("model")
|
|
||||||
fw.Write([]byte("whisper-1"))
|
|
||||||
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
|
||||||
ff.Write([]byte("fake-audio-bytes"))
|
|
||||||
mw.Close()
|
|
||||||
|
|
||||||
original := buf.Bytes()
|
|
||||||
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
|
||||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
|
||||||
|
|
||||||
if _, err := ExtractContext(r); err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
remaining, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(remaining, original) {
|
|
||||||
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
|
||||||
body := "model=whisper-1&extra=value"
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
if _, err := ExtractContext(r); err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
remaining, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if string(remaining) != body {
|
|
||||||
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetContext(t *testing.T) {
|
|
||||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
|
||||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("ContextKey not set or wrong type")
|
|
||||||
}
|
|
||||||
if data.Model != "llama3" {
|
|
||||||
t.Errorf("want %q got %q", "llama3", data.Model)
|
|
||||||
}
|
|
||||||
if data.ModelID != "llama3" {
|
|
||||||
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetContext_WithAlias(t *testing.T) {
|
|
||||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
|
||||||
data, _ := ctx.Value(ContextKey).(ReqContextData)
|
|
||||||
if data.Model != "llama" {
|
|
||||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
|
||||||
}
|
|
||||||
if data.ModelID != "llama3" {
|
|
||||||
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
|
||||||
parent := context.Background()
|
|
||||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
|
||||||
if v := parent.Value(ContextKey); v != nil {
|
|
||||||
t.Errorf("parent context was mutated: %v", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadContext(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
ctx context.Context
|
|
||||||
wantReq string
|
|
||||||
wantReal string
|
|
||||||
wantBool bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "model present, same name",
|
|
||||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
|
||||||
wantReq: "llama3",
|
|
||||||
wantReal: "llama3",
|
|
||||||
wantBool: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model present, aliased",
|
|
||||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
|
||||||
wantReq: "llama",
|
|
||||||
wantReal: "llama3",
|
|
||||||
wantBool: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model absent",
|
|
||||||
ctx: context.Background(),
|
|
||||||
wantReq: "",
|
|
||||||
wantReal: "",
|
|
||||||
wantBool: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model is empty string",
|
|
||||||
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
|
||||||
wantReq: "",
|
|
||||||
wantReal: "",
|
|
||||||
wantBool: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
gotData, ok := ReadContext(tt.ctx)
|
|
||||||
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
|
||||||
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,489 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||||
|
// the model config leaves concurrencyLimit unset.
|
||||||
|
const defaultConcurrencyLimit = 10
|
||||||
|
|
||||||
|
// activeSwap tracks one in-flight swap and the callers waiting on it.
|
||||||
|
type activeSwap struct {
|
||||||
|
modelID string
|
||||||
|
evict []string
|
||||||
|
waiters []HandlerReq
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIFO is the default scheduler. Requests are handled in a first-in, first-out order.
|
||||||
|
// To reduce swapping requests for a model that is already running will be handled
|
||||||
|
// immediately by the running process.
|
||||||
|
//
|
||||||
|
// Requests into this schedule are handled like this:
|
||||||
|
//
|
||||||
|
// A B C A B C --> A A B B C C
|
||||||
|
//
|
||||||
|
// The strategy is simple and reduces the number of swaps required.
|
||||||
|
type FIFO struct {
|
||||||
|
name string
|
||||||
|
logger *logmon.Monitor
|
||||||
|
planner Swapper
|
||||||
|
cfg config.FifoConfig
|
||||||
|
effects Effects
|
||||||
|
|
||||||
|
limits map[string]int
|
||||||
|
active map[string]*activeSwap
|
||||||
|
inFlight map[string]int
|
||||||
|
queued []HandlerReq
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFIFO builds a FIFO scheduler. Per-model concurrency limits are derived
|
||||||
|
// from models: each model's ConcurrencyLimit overrides defaultConcurrencyLimit
|
||||||
|
// when set to a value greater than zero.
|
||||||
|
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, models map[string]config.ModelConfig, eff Effects) *FIFO {
|
||||||
|
limits := make(map[string]int, len(models))
|
||||||
|
for id, mc := range models {
|
||||||
|
limit := defaultConcurrencyLimit
|
||||||
|
if mc.ConcurrencyLimit > 0 {
|
||||||
|
limit = mc.ConcurrencyLimit
|
||||||
|
}
|
||||||
|
limits[id] = limit
|
||||||
|
}
|
||||||
|
|
||||||
|
return &FIFO{
|
||||||
|
name: name,
|
||||||
|
logger: logger,
|
||||||
|
planner: planner,
|
||||||
|
cfg: cfg,
|
||||||
|
effects: eff,
|
||||||
|
limits: limits,
|
||||||
|
active: make(map[string]*activeSwap),
|
||||||
|
inFlight: make(map[string]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnRequest decides what to do with one incoming ServeHTTP request. It never
|
||||||
|
// blocks indefinitely: any work that has to wait (starting a process, stopping
|
||||||
|
// siblings, waiting for ready) is deferred to a swap goroutine and reported back
|
||||||
|
// via OnSwapDone.
|
||||||
|
//
|
||||||
|
// The decision tree, in order:
|
||||||
|
//
|
||||||
|
// 1. Unknown model — respond with ErrModelNotFound and move on.
|
||||||
|
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||||
|
// one swap serves all callers that asked for the same model.
|
||||||
|
// 3. Fast path — the target process is already ready, the planner sees
|
||||||
|
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||||
|
// ServeHTTP immediately.
|
||||||
|
// 4. Would collide with an in-flight swap (we'd stop their target, or they're
|
||||||
|
// stopping us) — park in the queue for OnSwapDone to drain.
|
||||||
|
// 5. Would evict a process that is still handling requests — park in the
|
||||||
|
// queue. OnServeDone will retry when the busy process drains.
|
||||||
|
// 6. Otherwise — start a new swap. This may run in parallel with other active
|
||||||
|
// swaps when their evict sets don't intersect.
|
||||||
|
func (s *FIFO) OnRequest(req HandlerReq) {
|
||||||
|
// (1) Unknown model.
|
||||||
|
state, ok := s.effects.ModelState(req.Model)
|
||||||
|
if !ok {
|
||||||
|
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||||
|
s.effects.GrantError(req, ErrModelNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (2) Join an in-flight swap for the same model.
|
||||||
|
if sw, ok := s.active[req.Model]; ok {
|
||||||
|
s.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", s.name, req.Model, len(sw.waiters)+1)
|
||||||
|
sw.waiters = append(sw.waiters, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
running := s.runningSet(req.Model)
|
||||||
|
evict := s.planner.EvictionFor(req.Model, running)
|
||||||
|
|
||||||
|
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||||
|
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||||
|
s.logger.Debugf("%s: fast-path serving model %s (already ready)", s.name, req.Model)
|
||||||
|
s.grantHandler(req, req.Model)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (4) Collision with an in-flight swap — queue.
|
||||||
|
if collidesWith(req.Model, evict, s.active) {
|
||||||
|
s.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", s.name, req.Model)
|
||||||
|
s.enqueue(req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (5) Would evict a busy process — queue until it drains.
|
||||||
|
if conflictsWithInFlight(evict, s.inFlight) {
|
||||||
|
s.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", s.name, req.Model)
|
||||||
|
s.enqueue(req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (6) Start a new (possibly parallel) swap.
|
||||||
|
s.logger.Debugf("%s: starting swap for model %s, evicting %v", s.name, req.Model, evict)
|
||||||
|
s.startSwap(req, evict, running)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCancel removes a request whose client has disconnected from the queue and
|
||||||
|
// from every in-flight swap's waiters. If the request was the sole waiter of an
|
||||||
|
// active swap, the swap goroutine is left to complete on its own — OnSwapDone
|
||||||
|
// will find no waiters and simply clean up. This prevents drainQueue from ever
|
||||||
|
// starting a model load for a caller that is no longer there.
|
||||||
|
func (s *FIFO) OnCancel(req HandlerReq) {
|
||||||
|
removed := false
|
||||||
|
|
||||||
|
// Prune from the queue.
|
||||||
|
if len(s.queued) > 0 {
|
||||||
|
kept := s.queued[:0]
|
||||||
|
for _, q := range s.queued {
|
||||||
|
if q.Respond == req.Respond {
|
||||||
|
removed = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, q)
|
||||||
|
}
|
||||||
|
s.queued = kept
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prune from any active swap's waiters.
|
||||||
|
for _, sw := range s.active {
|
||||||
|
filtered := sw.waiters[:0]
|
||||||
|
for _, w := range sw.waiters {
|
||||||
|
if w.Respond == req.Respond {
|
||||||
|
removed = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, w)
|
||||||
|
}
|
||||||
|
sw.waiters = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed {
|
||||||
|
s.logger.Debugf("%s: cancelled request for model %s pruned from scheduler", s.name, req.Model)
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnSwapDone fans the result out to every waiter that joined this swap, removes
|
||||||
|
// the swap from the active map, then walks the queue once, promoting any items
|
||||||
|
// that no longer collide with the remaining active set. FIFO order is preserved:
|
||||||
|
// items still blocked stay in place.
|
||||||
|
func (s *FIFO) OnSwapDone(ev SwapDone) {
|
||||||
|
sw, ok := s.active[ev.ModelID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(s.active, ev.ModelID)
|
||||||
|
|
||||||
|
for _, w := range sw.waiters {
|
||||||
|
if ev.Err != nil {
|
||||||
|
s.effects.GrantError(w, ev.Err)
|
||||||
|
} else {
|
||||||
|
s.grantHandler(w, ev.ModelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.drainQueue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnServeDone decrements the per-model in-flight count and, when that drops to
|
||||||
|
// zero, retries the queue: requests whose swap was deferred because they would
|
||||||
|
// have evicted this (now-idle) process can now proceed.
|
||||||
|
func (s *FIFO) OnServeDone(ev ServeDoneEvent) {
|
||||||
|
s.inFlight[ev.ModelID]--
|
||||||
|
if s.inFlight[ev.ModelID] <= 0 {
|
||||||
|
delete(s.inFlight, ev.ModelID)
|
||||||
|
s.drainQueue()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnUnload reconciles router-owned state with the impending Stop, performs the
|
||||||
|
// Stop (synchronously, via Effects) so callers of Unload remain blocked until
|
||||||
|
// each targeted process has exited, then drains the queue.
|
||||||
|
func (s *FIFO) OnUnload(targets []string, timeout time.Duration) {
|
||||||
|
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||||
|
|
||||||
|
targetSet := make(map[string]bool, len(targets))
|
||||||
|
for _, id := range targets {
|
||||||
|
targetSet[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release waiters of any in-flight swap whose target is being unloaded.
|
||||||
|
// The swap goroutine itself is left to finish on its own; when its
|
||||||
|
// SwapDone arrives, OnSwapDone will find no entry in active and drop it.
|
||||||
|
for id := range targetSet {
|
||||||
|
sw, ok := s.active[id]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, w := range sw.waiters {
|
||||||
|
s.effects.GrantError(w, unloadErr)
|
||||||
|
}
|
||||||
|
delete(s.active, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop queued requests addressed to unloaded models. Requests for other
|
||||||
|
// models stay queued and may benefit from drainQueue at the end.
|
||||||
|
if len(s.queued) > 0 {
|
||||||
|
kept := s.queued[:0]
|
||||||
|
for _, w := range s.queued {
|
||||||
|
if targetSet[w.Model] {
|
||||||
|
s.effects.GrantError(w, unloadErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, w)
|
||||||
|
}
|
||||||
|
s.queued = kept
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the targeted processes. Done synchronously so Unload's caller can
|
||||||
|
// rely on "after Unload returns, the process is stopped". inFlight is
|
||||||
|
// intentionally NOT cleared here: each dying handler will fire its tracked
|
||||||
|
// serve and reach OnServeDone in the normal way.
|
||||||
|
s.effects.StopProcesses(timeout, targets)
|
||||||
|
|
||||||
|
// Removing entries from active above may have unblocked queued requests
|
||||||
|
// that previously collided with the now-cancelled swaps.
|
||||||
|
s.drainQueue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnShutdown grants err to every waiter still held by the scheduler.
|
||||||
|
func (s *FIFO) OnShutdown(err error) {
|
||||||
|
for _, sw := range s.active {
|
||||||
|
for _, w := range sw.waiters {
|
||||||
|
s.effects.GrantError(w, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, w := range s.queued {
|
||||||
|
s.effects.GrantError(w, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// grantHandler hands the caller a tracked handler for modelID and, only if the
|
||||||
|
// caller was still there to receive it, bumps the in-flight count. Incrementing
|
||||||
|
// when the grant failed would strand the counter and block future evictions.
|
||||||
|
// Requests that would exceed the model's concurrency limit are rejected with a
|
||||||
|
// shared.NewConcurrencyLimitError (HTTP 429 with Retry-After).
|
||||||
|
func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
|
||||||
|
if s.inFlight[modelID] >= s.limit(modelID) {
|
||||||
|
s.effects.GrantError(req, shared.ConcurrencyLimitError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := shared.SetReqData(req.Ctx, "fifo_priority", strconv.Itoa(s.cfg.Priority[req.Model])); err != nil {
|
||||||
|
s.logger.Debugf("failed to set fifo_priority metadata: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.effects.GrantServe(req, modelID) {
|
||||||
|
s.inFlight[modelID]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// limit returns the per-model concurrency cap, defaulting to
|
||||||
|
// defaultConcurrencyLimit when the model has no explicit entry.
|
||||||
|
func (s *FIFO) limit(modelID string) int {
|
||||||
|
if l, ok := s.limits[modelID]; ok {
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
return defaultConcurrencyLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
// startSwap records the swap as active and launches it via Effects. running is
|
||||||
|
// the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against
|
||||||
|
// the same picture it decided on.
|
||||||
|
func (s *FIFO) startSwap(initial HandlerReq, evict, running []string) {
|
||||||
|
s.active[initial.Model] = &activeSwap{
|
||||||
|
modelID: initial.Model,
|
||||||
|
evict: evict,
|
||||||
|
waiters: []HandlerReq{initial},
|
||||||
|
}
|
||||||
|
s.planner.OnSwapStart(initial.Model, running)
|
||||||
|
s.effects.StartSwap(initial.Model, evict)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enqueue inserts req into the queue in priority order: it goes just before the
|
||||||
|
// first queued item whose priority is strictly lower, so higher-priority models
|
||||||
|
// are serviced first while equal-priority requests keep their arrival (FIFO)
|
||||||
|
// order. Priorities come from the FifoConfig; unlisted models default to 0.
|
||||||
|
func (s *FIFO) enqueue(req HandlerReq) {
|
||||||
|
p := s.cfg.Priority[req.Model]
|
||||||
|
i := len(s.queued)
|
||||||
|
for j, q := range s.queued {
|
||||||
|
if s.cfg.Priority[q.Model] < p {
|
||||||
|
i = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.queued = append(s.queued, HandlerReq{})
|
||||||
|
copy(s.queued[i+1:], s.queued[i:])
|
||||||
|
s.queued[i] = req
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainQueue walks the queued requests in order, re-running the OnRequest
|
||||||
|
// decision tree against the (now smaller) active set. Items that can now start
|
||||||
|
// or join become satisfied; items still blocked remain queued in original order
|
||||||
|
// so they get another chance on the next swap completion.
|
||||||
|
func (s *FIFO) drainQueue() {
|
||||||
|
if len(s.queued) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pending := s.queued
|
||||||
|
var remaining []HandlerReq
|
||||||
|
for _, req := range pending {
|
||||||
|
state, ok := s.effects.ModelState(req.Model)
|
||||||
|
if !ok {
|
||||||
|
s.effects.GrantError(req, ErrModelNotFound)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sw, ok := s.active[req.Model]; ok {
|
||||||
|
s.logger.Debugf("%s: queued request for model %s now joining in-flight swap", s.name, req.Model)
|
||||||
|
sw.waiters = append(sw.waiters, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
running := s.runningSet(req.Model)
|
||||||
|
evict := s.planner.EvictionFor(req.Model, running)
|
||||||
|
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||||
|
s.logger.Debugf("%s: queued request for model %s now served fast-path", s.name, req.Model)
|
||||||
|
s.grantHandler(req, req.Model)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if collidesWith(req.Model, evict, s.active) {
|
||||||
|
remaining = append(remaining, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if conflictsWithInFlight(evict, s.inFlight) {
|
||||||
|
remaining = append(remaining, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", s.name, req.Model, evict)
|
||||||
|
s.startSwap(req, evict, running)
|
||||||
|
}
|
||||||
|
s.queued = remaining
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runningSet is the live model set handed to the Swapper: every process the
|
||||||
|
// baseRouter reports as running, unioned with the targets of in-flight swaps
|
||||||
|
// (excluding excludeActive, the model whose own swap is being decided — its
|
||||||
|
// in-flight entry must not count as "already running"). The result is sorted so
|
||||||
|
// eviction decisions derived from it are deterministic.
|
||||||
|
func (s *FIFO) runningSet(excludeActive string) []string {
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
var out []string
|
||||||
|
add := func(id string) {
|
||||||
|
if _, dup := seen[id]; dup {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
for id := range s.effects.RunningModels() {
|
||||||
|
add(id)
|
||||||
|
}
|
||||||
|
for _, id := range activeTargets(s.active, excludeActive) {
|
||||||
|
add(id)
|
||||||
|
}
|
||||||
|
sort.Strings(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||||
|
// The planner uses this to account for models committed to but not yet reflected
|
||||||
|
// in process state.
|
||||||
|
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||||
|
if len(active) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(active))
|
||||||
|
for id := range active {
|
||||||
|
if id == exclude {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// collidesWith reports whether a new swap with this target and evict set can
|
||||||
|
// safely run alongside the currently active swaps. Same-target callers should
|
||||||
|
// JOIN (handled before this) — they do not collide with themselves.
|
||||||
|
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||||
|
for id, sw := range active {
|
||||||
|
if id == target {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if containsString(evict, id) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if containsString(sw.evict, target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if slicesOverlap(evict, sw.evict) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// slicesOverlap reports whether xs and ys share any common element.
|
||||||
|
func slicesOverlap(xs, ys []string) bool {
|
||||||
|
for _, x := range xs {
|
||||||
|
if containsString(ys, x) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// conflictsWithInFlight reports whether any model named in evict has a
// positive count in inFlight. Callers use it to hold back a swap while a model
// it would evict is still handling requests. Models missing from inFlight
// count as idle.
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
	for idx := 0; idx < len(evict); idx++ {
		if busy := inFlight[evict[idx]]; busy > 0 {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// containsString reports whether s is an element of xs.
func containsString(xs []string, s string) bool {
	for idx := range xs {
		if xs[idx] == s {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||||
|
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||||
|
// drained first so the consumer always sees the latest position.
|
||||||
|
func broadcastQueuePositions(queued []HandlerReq) {
|
||||||
|
for i, req := range queued {
|
||||||
|
pos := i + 1
|
||||||
|
select {
|
||||||
|
case req.PositionCh <- pos:
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case <-req.PositionCh:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case req.PositionCh <- pos:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,779 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FIFO methods all run on the router's single run-loop goroutine, so these
|
||||||
|
// tests drive them directly and synchronously. A swap is "completed" by calling
|
||||||
|
// OnSwapDone, a served request "finishes" by calling OnServeDone — exactly the
|
||||||
|
// events the run loop would deliver. fakeEffects records every side-effect and
|
||||||
|
// stubPlanner supplies a fixed eviction set per target.
|
||||||
|
|
||||||
|
// stubPlanner is a test planner whose eviction answers are fixed up front.
type stubPlanner struct {
	// evict maps a swap target to the eviction list EvictionFor returns for it.
	evict map[string][]string
}
|
||||||
|
|
||||||
|
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||||
|
if s.evict == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.evict[target]
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnSwapStart is a no-op; the stub keeps no swap bookkeeping.
func (s *stubPlanner) OnSwapStart(_ string, _ []string) {
}
|
||||||
|
|
||||||
|
// grantRec captures a single GrantError or GrantServe call made on fakeEffects.
// A non-nil err marks an error grant. Otherwise the record is a serve grant and
// serve holds the value GrantServe returned for it.
type grantRec struct {
	model string // model the grant was for
	err   error  // set only for error grants
	serve bool   // GrantServe result; meaningful only when err is nil
}
|
||||||
|
|
||||||
|
// startRec captures the arguments of one StartSwap call on fakeEffects.
type startRec struct {
	model string   // swap target
	evict []string // eviction list passed with the swap
}
|
||||||
|
|
||||||
|
// stopRec captures the arguments of one StopProcesses call on fakeEffects.
type stopRec struct {
	timeout time.Duration // timeout passed to StopProcesses
	ids     []string      // model IDs passed to StopProcesses
}
|
||||||
|
|
||||||
|
// fakeEffects is an in-memory scheduler.Effects. Tests program process states
|
||||||
|
// and GrantServe outcomes, then assert on the recorded calls.
|
||||||
|
type fakeEffects struct {
|
||||||
|
states map[string]process.ProcessState // model -> state; missing => not handled
|
||||||
|
serveResult map[string]bool // GrantServe return per model (default true)
|
||||||
|
lastServeReq HandlerReq
|
||||||
|
|
||||||
|
starts []startRec
|
||||||
|
grants []grantRec
|
||||||
|
stops []stopRec
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeEffects() *fakeEffects {
|
||||||
|
return &fakeEffects{
|
||||||
|
states: map[string]process.ProcessState{},
|
||||||
|
serveResult: map[string]bool{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeEffects) ModelState(modelID string) (process.ProcessState, bool) {
|
||||||
|
st, ok := f.states[modelID]
|
||||||
|
return st, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeEffects) RunningModels() map[string]process.ProcessState {
|
||||||
|
out := make(map[string]process.ProcessState)
|
||||||
|
for id, st := range f.states {
|
||||||
|
if st == process.StateStopped || st == process.StateShutdown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[id] = st
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeEffects) StartSwap(modelID string, evict []string) {
|
||||||
|
f.starts = append(f.starts, startRec{model: modelID, evict: evict})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeEffects) GrantError(req HandlerReq, err error) {
|
||||||
|
f.grants = append(f.grants, grantRec{model: req.Model, err: err})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeEffects) GrantServe(req HandlerReq, modelID string) bool {
|
||||||
|
ok := true
|
||||||
|
if v, set := f.serveResult[modelID]; set {
|
||||||
|
ok = v
|
||||||
|
}
|
||||||
|
f.lastServeReq = req
|
||||||
|
f.grants = append(f.grants, grantRec{model: modelID, serve: ok})
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeEffects) StopProcesses(timeout time.Duration, ids []string) {
|
||||||
|
f.stops = append(f.stops, stopRec{timeout: timeout, ids: ids})
|
||||||
|
}
|
||||||
|
|
||||||
|
// served counts grants that handed modelID a handler and were received.
|
||||||
|
func (f *fakeEffects) served(modelID string) int {
|
||||||
|
n := 0
|
||||||
|
for _, g := range f.grants {
|
||||||
|
if g.err == nil && g.serve && g.model == modelID {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// errored counts error grants, optionally filtered by model ("" = any).
|
||||||
|
func (f *fakeEffects) errored(model string) int {
|
||||||
|
n := 0
|
||||||
|
for _, g := range f.grants {
|
||||||
|
if g.err != nil && (model == "" || g.model == model) {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// startsFor counts StartSwap calls for modelID.
|
||||||
|
func (f *fakeEffects) startsFor(modelID string) int {
|
||||||
|
n := 0
|
||||||
|
for _, s := range f.starts {
|
||||||
|
if s.model == modelID {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFIFO(planner Swapper, eff Effects) *FIFO {
|
||||||
|
return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, nil, eff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// req builds a HandlerReq that carries only the target model.
func req(model string) HandlerReq {
	var r HandlerReq
	r.Model = model
	return r
}
|
||||||
|
|
||||||
|
// reqCh creates a HandlerReq with a unique Respond channel so OnCancel can
|
||||||
|
// identify it among queued requests and swap waiters.
|
||||||
|
func reqCh(model string) HandlerReq {
|
||||||
|
return HandlerReq{
|
||||||
|
Model: model,
|
||||||
|
Respond: make(chan HandlerResp, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_FastPath(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
if got := eff.startsFor("a"); got != 0 {
|
||||||
|
t.Errorf("StartSwap calls=%d want 0 (fast path should not swap)", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_GrantSetsPriorityMetadata(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
cfg := config.FifoConfig{Priority: map[string]int{"a": 7}}
|
||||||
|
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, cfg, nil, eff)
|
||||||
|
|
||||||
|
ctx := shared.SetContext(context.Background(), shared.ReqContextData{ModelID: "a", Metadata: make(map[string]string)})
|
||||||
|
s.OnRequest(HandlerReq{Model: "a", Ctx: ctx})
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
data, ok := shared.ReadContext(eff.lastServeReq.Ctx)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("context data missing from granted request")
|
||||||
|
}
|
||||||
|
if data.Metadata["fifo_priority"] != "7" {
|
||||||
|
t.Errorf("fifo_priority = %q, want 7", data.Metadata["fifo_priority"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_ModelNotFound(t *testing.T) {
|
||||||
|
eff := newFakeEffects() // no states => model unknown
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("ghost"))
|
||||||
|
|
||||||
|
if got := len(eff.starts); got != 0 {
|
||||||
|
t.Errorf("StartSwap calls=%d want 0", got)
|
||||||
|
}
|
||||||
|
if eff.errored("ghost") != 1 {
|
||||||
|
t.Fatalf("want 1 error grant for ghost, grants=%+v", eff.grants)
|
||||||
|
}
|
||||||
|
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||||
|
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_OnDemandStartThenServe(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 0 {
|
||||||
|
t.Errorf("served(a)=%d want 0 before swap completes", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap finishes, model is now ready.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1 after swap done", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_JoinInFlightSwap(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // starts swap
|
||||||
|
s.OnRequest(req("a")) // joins
|
||||||
|
s.OnRequest(req("a")) // joins
|
||||||
|
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(a)=%d want 1 (all three share one swap)", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 3 {
|
||||||
|
t.Errorf("served(a)=%d want 3 (one swap serves all waiters)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_SwapDoneError_FailsAllWaiters(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||||
|
|
||||||
|
if eff.served("a") != 0 {
|
||||||
|
t.Errorf("served(a)=%d want 0 on swap error", eff.served("a"))
|
||||||
|
}
|
||||||
|
if eff.errored("a") != 2 {
|
||||||
|
t.Errorf("errored(a)=%d want 2 (both waiters fail)", eff.errored("a"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_QueueOnEvictionCollision covers a request whose target evicts the
|
||||||
|
// model currently being swapped: it must queue until that swap finishes AND its
|
||||||
|
// served request drains, because starting it would stop a busy process.
|
||||||
|
func TestFIFO_QueueOnEvictionCollision(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
// Loading b evicts a.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // StartSwap(a)
|
||||||
|
s.OnRequest(req("b")) // collides with a's in-flight swap -> queue
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("b started early: StartSwap(b)=%d want 0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a becomes ready and is granted (now serving, inFlight[a]=1).
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("b started while a is serving: StartSwap(b)=%d want 0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a's request finishes -> a no longer in-flight -> b may now swap.
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a drained", got)
|
||||||
|
}
|
||||||
|
if got := eff.starts[len(eff.starts)-1].evict; len(got) != 1 || got[0] != "a" {
|
||||||
|
t.Errorf("b swap evict=%v want [a]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_DisjointSwapsRunInParallel verifies two requests with
|
||||||
|
// non-conflicting evict sets both start without waiting for each other.
|
||||||
|
func TestFIFO_DisjointSwapsRunInParallel(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{}, eff) // empty evicts
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("b"))
|
||||||
|
|
||||||
|
if eff.startsFor("a") != 1 || eff.startsFor("b") != 1 {
|
||||||
|
t.Fatalf("StartSwap a=%d b=%d want 1 each (parallel)", eff.startsFor("a"), eff.startsFor("b"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_OverlappingEvictSetsDoNotRunInParallel verifies two swaps with
|
||||||
|
// different targets that evict the *same* model do not run concurrently: the
|
||||||
|
// second must queue rather than double-evict the shared model. Neither target is
|
||||||
|
// in the other's evict set, so this is only caught by the evict-set overlap
|
||||||
|
// check in collidesWith.
|
||||||
|
func TestFIFO_OverlappingEvictSetsDoNotRunInParallel(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.states["x"] = process.StateReady // shared eviction target, running
|
||||||
|
// Loading a or b both require evicting x.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"x"}, "b": {"x"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // StartSwap(a, [x])
|
||||||
|
s.OnRequest(req("b")) // overlaps a's evict set ([x]) -> queue
|
||||||
|
if eff.startsFor("a") != 1 {
|
||||||
|
t.Fatalf("StartSwap(a)=%d want 1", eff.startsFor("a"))
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("b started in parallel while a evicts x: StartSwap(b)=%d want 0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a's swap completes and x is gone; b can now evict nothing and start.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
eff.states["x"] = process.StateStopped
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a's swap drained", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_QueueDrainPromotesMultiple verifies completing one swap unblocks
|
||||||
|
// every queued request that no longer collides — they all start together.
|
||||||
|
func TestFIFO_QueueDrainPromotesMultiple(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.states["c"] = process.StateStopped
|
||||||
|
// a's swap evicts both b and c; b and c evict nothing.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"b", "c"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // StartSwap(a, [b,c])
|
||||||
|
s.OnRequest(req("b")) // collides (in a's evict set) -> queue
|
||||||
|
s.OnRequest(req("c")) // collides -> queue
|
||||||
|
if eff.startsFor("b") != 0 || eff.startsFor("c") != 0 {
|
||||||
|
t.Fatalf("b/c started early")
|
||||||
|
}
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
|
||||||
|
// b and c have empty evict sets and don't evict a, so both start now.
|
||||||
|
if eff.startsFor("b") != 1 || eff.startsFor("c") != 1 {
|
||||||
|
t.Fatalf("StartSwap b=%d c=%d want 1 each after a done", eff.startsFor("b"), eff.startsFor("c"))
|
||||||
|
}
|
||||||
|
if eff.served("a") != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1", eff.served("a"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_QueueCollation verifies duplicate requests collapse into one swap
|
||||||
|
// per model: the second request for each model joins the active swap (at arrival
|
||||||
|
// or at drain time) rather than triggering its own swap.
|
||||||
|
func TestFIFO_QueueCollation(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
for _, id := range []string{"a", "b", "c"} {
|
||||||
|
eff.states[id] = process.StateStopped
|
||||||
|
}
|
||||||
|
// Each model evicts the other two: all swaps are mutually exclusive.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{
|
||||||
|
"a": {"b", "c"},
|
||||||
|
"b": {"a", "c"},
|
||||||
|
"c": {"a", "b"},
|
||||||
|
}}, eff)
|
||||||
|
|
||||||
|
for _, id := range []string{"a", "b", "c", "a", "b", "c"} {
|
||||||
|
s.OnRequest(req(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain a, then its served requests, which promotes b; repeat for b -> c.
|
||||||
|
drain := func(model string, waiters int) {
|
||||||
|
eff.states[model] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: model})
|
||||||
|
for i := 0; i < waiters; i++ {
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
drain("a", 2)
|
||||||
|
drain("b", 2)
|
||||||
|
drain("c", 2)
|
||||||
|
|
||||||
|
for _, id := range []string{"a", "b", "c"} {
|
||||||
|
if got := eff.startsFor(id); got != 1 {
|
||||||
|
t.Errorf("StartSwap(%s)=%d want 1 (collation)", id, got)
|
||||||
|
}
|
||||||
|
if got := eff.served(id); got != 2 {
|
||||||
|
t.Errorf("served(%s)=%d want 2", id, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_NoSwapWhileServing verifies a model still handling requests is not
|
||||||
|
// evicted: the evicting request waits until every in-flight request drains.
|
||||||
|
func TestFIFO_NoSwapWhileServing(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // fast path, inFlight[a]=1
|
||||||
|
s.OnRequest(req("a")) // fast path, inFlight[a]=2
|
||||||
|
s.OnRequest(req("b")) // would evict busy a -> queue
|
||||||
|
if eff.startsFor("b") != 0 {
|
||||||
|
t.Fatalf("b started while a serving")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=1
|
||||||
|
if eff.startsFor("b") != 0 {
|
||||||
|
t.Fatalf("b started while a still serving one request")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=0
|
||||||
|
if eff.startsFor("b") != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a fully drained", eff.startsFor("b"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_GrantServeFalseDoesNotLeakInFlight verifies that when a caller has
|
||||||
|
// walked away (GrantServe returns false) the in-flight count is not bumped, so a
|
||||||
|
// later evicting request is not blocked forever.
|
||||||
|
func TestFIFO_GrantServeFalseDoesNotLeakInFlight(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.serveResult["a"] = false // a's waiter is gone by grant time
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails, inFlight[a] stays 0
|
||||||
|
|
||||||
|
// b evicts a; since a is not in-flight, b should start immediately.
|
||||||
|
s.OnRequest(req("b"))
|
||||||
|
if eff.startsFor("b") != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 (no leaked in-flight on a)", eff.startsFor("b"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_OnShutdown_FailsAllWaiters verifies shutdown errors every waiter the
|
||||||
|
// scheduler holds: active-swap waiters and queued requests alike.
|
||||||
|
func TestFIFO_OnShutdown_FailsAllWaiters(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
for _, id := range []string{"a", "b", "c"} {
|
||||||
|
eff.states[id] = process.StateStopped
|
||||||
|
}
|
||||||
|
// a and b load in parallel; c collides with both and queues.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"c": {"a", "b"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // StartSwap(a)
|
||||||
|
s.OnRequest(req("a")) // join a
|
||||||
|
s.OnRequest(req("b")) // StartSwap(b)
|
||||||
|
s.OnRequest(req("b")) // join b
|
||||||
|
s.OnRequest(req("c")) // queued
|
||||||
|
|
||||||
|
s.OnShutdown(errors.New("shutting down"))
|
||||||
|
|
||||||
|
if got := eff.errored(""); got != 5 {
|
||||||
|
t.Errorf("error grants=%d want 5 (2 a + 2 b + 1 c)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_OnUnload_ReleasesActiveWaiters(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // active swap a with one waiter
|
||||||
|
s.OnRequest(req("a")) // join
|
||||||
|
|
||||||
|
s.OnUnload([]string{"a"}, time.Second)
|
||||||
|
|
||||||
|
if got := eff.errored("a"); got != 2 {
|
||||||
|
t.Errorf("errored(a)=%d want 2 (active swap waiters released)", got)
|
||||||
|
}
|
||||||
|
if len(eff.stops) != 1 || len(eff.stops[0].ids) != 1 || eff.stops[0].ids[0] != "a" {
|
||||||
|
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||||
|
}
|
||||||
|
if eff.stops[0].timeout != time.Second {
|
||||||
|
t.Errorf("StopProcesses timeout=%v want 1s", eff.stops[0].timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFIFO_OnUnload_DropsQueuedRequests(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
// b evicts a, so a request for b queues while a is loading.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // StartSwap(a)
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
|
||||||
|
s.OnUnload([]string{"b"}, time.Second)
|
||||||
|
|
||||||
|
if got := eff.errored("b"); got != 1 {
|
||||||
|
t.Errorf("errored(b)=%d want 1 (queued request dropped)", got)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Errorf("StartSwap(b)=%d want 0 (b should never start)", got)
|
||||||
|
}
|
||||||
|
// a's swap is untouched: its waiter is neither served nor errored yet.
|
||||||
|
if eff.served("a") != 0 || eff.errored("a") != 0 {
|
||||||
|
t.Errorf("a swap should be untouched: served=%d errored=%d", eff.served("a"), eff.errored("a"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_PriorityQueueOrder verifies queued requests are ordered by descending
|
||||||
|
// priority, with arrival (FIFO) order preserved among equal-priority models.
|
||||||
|
func TestFIFO_PriorityQueueOrder(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
for _, m := range []string{"z", "A", "B", "C", "D"} {
|
||||||
|
eff.states[m] = process.StateStopped
|
||||||
|
}
|
||||||
|
// z's swap evicts every other model, so any request that arrives while z is
|
||||||
|
// loading collides with z's in-flight swap and parks in the queue.
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}}
|
||||||
|
cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}}
|
||||||
|
s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, nil, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D])
|
||||||
|
|
||||||
|
// Arrive out of priority order; B before C exercises FIFO tie-breaking.
|
||||||
|
for _, m := range []string{"B", "D", "C", "A"} {
|
||||||
|
s.OnRequest(req(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
got := make([]string, len(s.queued))
|
||||||
|
for i, q := range s.queued {
|
||||||
|
got[i] = q.Model
|
||||||
|
}
|
||||||
|
want := []string{"A", "B", "C", "D"}
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("queue=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("queue=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_OnCancel_QueuedRequest verifies that cancelling a queued request
|
||||||
|
// prevents drainQueue from ever starting a model load for it. Without OnCancel
|
||||||
|
// the dead request would sit in the queue until a drain triggers a wasted swap.
|
||||||
|
func TestFIFO_OnCancel_QueuedRequest(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
// b evicts a, so a request for b queues while a is loading.
|
||||||
|
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // StartSwap(a)
|
||||||
|
|
||||||
|
cancelledReq := reqCh("b")
|
||||||
|
s.OnRequest(cancelledReq) // queued (collides with a's in-flight swap)
|
||||||
|
if len(s.queued) != 1 {
|
||||||
|
t.Fatalf("queue len=%d want 1 before cancel", len(s.queued))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client disconnects.
|
||||||
|
s.OnCancel(cancelledReq)
|
||||||
|
|
||||||
|
if len(s.queued) != 0 {
|
||||||
|
t.Fatalf("queue len=%d want 0 after cancel", len(s.queued))
|
||||||
|
}
|
||||||
|
|
||||||
|
// a's swap finishes; drainQueue runs but b is gone — no swap for b.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Errorf("StartSwap(b)=%d want 0 (cancelled request should not trigger a load)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_OnCancel_SwapWaiter verifies that cancelling a request that joined an
|
||||||
|
// in-flight swap removes it from the waiter list. When the swap completes, the
|
||||||
|
// cancelled waiter receives no grant and does not bump the in-flight count.
|
||||||
|
func TestFIFO_OnCancel_SwapWaiter(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
liveReq := reqCh("a")
|
||||||
|
cancelledReq := reqCh("a")
|
||||||
|
s.OnRequest(liveReq) // starts swap
|
||||||
|
s.OnRequest(cancelledReq) // joins
|
||||||
|
|
||||||
|
if sw := s.active["a"]; len(sw.waiters) != 2 {
|
||||||
|
t.Fatalf("waiters=%d want 2", len(sw.waiters))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnCancel(cancelledReq)
|
||||||
|
|
||||||
|
if sw := s.active["a"]; len(sw.waiters) != 1 {
|
||||||
|
t.Fatalf("waiters=%d want 1 after cancel", len(sw.waiters))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap finishes: only the live waiter is granted.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1 (only the non-cancelled waiter)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_OnCancel_NotPresent is a no-op: cancelling a request that was already
|
||||||
|
// granted (and is no longer queued or waiting) must not affect anything.
|
||||||
|
func TestFIFO_OnCancel_NotPresent(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
r := reqCh("a")
|
||||||
|
s.OnRequest(r) // fast-path served immediately
|
||||||
|
|
||||||
|
// Cancel after grant — should be a harmless no-op.
|
||||||
|
s.OnCancel(r)
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1 (cancel of granted request is a no-op)", got)
|
||||||
|
}
|
||||||
|
if len(s.queued) != 0 {
|
||||||
|
t.Errorf("queue should be empty, len=%d", len(s.queued))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFIFOWithLimit builds a FIFO whose single model has the given concurrency
|
||||||
|
// limit, already in StateReady so every request exercises the fast path.
|
||||||
|
func newFIFOWithLimit(t *testing.T, model string, limit int) (*FIFO, *fakeEffects) {
|
||||||
|
t.Helper()
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states[model] = process.StateReady
|
||||||
|
models := map[string]config.ModelConfig{
|
||||||
|
model: {ConcurrencyLimit: limit},
|
||||||
|
}
|
||||||
|
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
|
||||||
|
return s, eff
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_ConcurrencyLimit_RejectsOverLimit verifies that a request arriving
|
||||||
|
// while the model is at capacity gets an error grant instead of being served,
|
||||||
|
// and that a new request succeeds once an in-flight one completes.
|
||||||
|
func TestFIFO_ConcurrencyLimit_RejectsOverLimit(t *testing.T) {
|
||||||
|
s, eff := newFIFOWithLimit(t, "a", 1)
|
||||||
|
|
||||||
|
// First request: served (inFlight 0 → 1).
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second request while slot is occupied: rejected with HTTPError 429.
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.errored("a"); got != 1 {
|
||||||
|
t.Fatalf("errored(a)=%d want 1 (over-limit)", got)
|
||||||
|
}
|
||||||
|
var httpErr shared.HTTPError
|
||||||
|
if !errors.As(eff.grants[len(eff.grants)-1].err, &httpErr) {
|
||||||
|
t.Fatalf("err=%v want HTTPError", eff.grants[len(eff.grants)-1].err)
|
||||||
|
}
|
||||||
|
if httpErr.StatusCode() != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("StatusCode()=%d want 429", httpErr.StatusCode())
|
||||||
|
}
|
||||||
|
if httpErr.Header().Get("Retry-After") == "" {
|
||||||
|
t.Fatal("missing Retry-After header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After the in-flight request finishes, a new request succeeds.
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.served("a"); got != 2 {
|
||||||
|
t.Fatalf("served(a)=%d want 2 after drain", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_ConcurrencyLimit_DefaultIsTen verifies that a model without an
|
||||||
|
// explicit ConcurrencyLimit gets the default cap of 10.
|
||||||
|
func TestFIFO_ConcurrencyLimit_DefaultIsTen(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
// nil models → every model gets defaultConcurrencyLimit (10).
|
||||||
|
s := newFIFO(&stubPlanner{}, eff)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 10 {
|
||||||
|
t.Fatalf("served(a)=%d want 10 (default limit)", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 11th request is rejected.
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.errored("a"); got != 1 {
|
||||||
|
t.Fatalf("errored(a)=%d want 1 (over default limit)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_ConcurrencyLimit_CustomLimit verifies a ConcurrencyLimit greater
|
||||||
|
// than zero overrides the default.
|
||||||
|
func TestFIFO_ConcurrencyLimit_CustomLimit(t *testing.T) {
|
||||||
|
s, eff := newFIFOWithLimit(t, "a", 2)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 2 {
|
||||||
|
t.Fatalf("served(a)=%d want 2 (custom limit)", got)
|
||||||
|
}
|
||||||
|
if got := eff.errored("a"); got != 1 {
|
||||||
|
t.Fatalf("errored(a)=%d want 1 (over custom limit)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFIFO_ConcurrencyLimit_SwapWaiters verifies that when more swap waiters
|
||||||
|
// exist than the concurrency limit, excess waiters are rejected on swap
|
||||||
|
// completion rather than exceeding the limit.
|
||||||
|
func TestFIFO_ConcurrencyLimit_SwapWaiters(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
models := map[string]config.ModelConfig{
|
||||||
|
"a": {ConcurrencyLimit: 2},
|
||||||
|
}
|
||||||
|
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
|
||||||
|
|
||||||
|
// Three requests arrive while model is loading: one starts swap, two join.
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap completes: two served (limit), one rejected.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 2 {
|
||||||
|
t.Fatalf("served(a)=%d want 2 (limit on swap completion)", got)
|
||||||
|
}
|
||||||
|
if got := eff.errored("a"); got != 1 {
|
||||||
|
t.Fatalf("errored(a)=%d want 1 (excess waiter rejected)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,135 @@
|
|||||||
|
// Package scheduler contains the request-scheduling strategies used by the
|
||||||
|
// router's baseRouter. A Scheduler owns the queue, in-flight tracking, and the
|
||||||
|
// decision tree for when to start a swap versus queue a request. The baseRouter
|
||||||
|
// owns the channels, run loop, and process machinery, and exposes the
|
||||||
|
// side-effects a scheduler needs through the Effects interface.
|
||||||
|
//
|
||||||
|
// Splitting these apart lets the scheduling strategy be swapped out
|
||||||
|
// independently of both the process machinery (baseRouter) and the eviction
|
||||||
|
// policy (Swapper). FIFO is the first and currently only implementation.
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrModelNotFound is granted to callers whose model is not handled by this
|
||||||
|
// router. It is an alias for shared.ErrNoLocalModelFound.
|
||||||
|
var ErrModelNotFound = shared.ErrNoLocalModelFound
|
||||||
|
|
||||||
|
// Swapper is the eviction policy. Given a target model and the set of models
// the scheduler considers live, it names the ones that have to be stopped
// before the target can serve. The policy is independent of the scheduling
// strategy: any Scheduler can be paired with any Swapper.
type Swapper interface {
	// EvictionFor reports which of the running model IDs must be stopped
	// before target can serve.
	//
	// running is the full live set as seen by the scheduler: every process
	// that is not stopped, plus the targets of in-flight swaps the scheduler
	// has already committed to (those are not yet visible in process state).
	// Implementations do not look at process state themselves.
	//
	// This is a pure decision and must not log.
	EvictionFor(target string, running []string) []string

	// OnSwapStart is called once at the start of every swap and receives the
	// same running set that EvictionFor was given for that decision. This is
	// where a planner may log its decision, at any verbosity it likes.
	OnSwapStart(target string, running []string)
}
|
||||||
|
|
||||||
|
// Scheduler decides what happens to each event the router's run loop receives.
|
||||||
|
// All methods run on that single run-loop goroutine, so implementations need no
|
||||||
|
// internal locking for their own state.
|
||||||
|
type Scheduler interface {
|
||||||
|
// OnRequest handles one incoming ServeHTTP request.
|
||||||
|
OnRequest(req HandlerReq)
|
||||||
|
// OnCancel handles a request whose client has disconnected before it was
|
||||||
|
// granted. The scheduler must remove the request from its queue and from
|
||||||
|
// any in-flight swap's waiters so it never triggers a model load or grant
|
||||||
|
// for a caller that is no longer there.
|
||||||
|
OnCancel(req HandlerReq)
|
||||||
|
// OnSwapDone handles a swap goroutine reporting completion.
|
||||||
|
OnSwapDone(ev SwapDone)
|
||||||
|
// OnServeDone handles a tracked ServeHTTP finishing (in-flight decrement).
|
||||||
|
OnServeDone(ev ServeDoneEvent)
|
||||||
|
// OnUnload reconciles scheduler state for an unload, stops the targeted
|
||||||
|
// processes via Effects, and drains the queue. It must block until the
|
||||||
|
// targeted processes have stopped.
|
||||||
|
OnUnload(targets []string, timeout time.Duration)
|
||||||
|
// OnShutdown grants err to every waiter the scheduler still holds (active
|
||||||
|
// swap waiters and queued requests). Process teardown is the baseRouter's
|
||||||
|
// responsibility.
|
||||||
|
OnShutdown(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Effects is implemented by the baseRouter. The scheduler calls back through it
|
||||||
|
// for every side-effect: inspecting process state, launching swaps, responding
|
||||||
|
// to callers, and stopping processes.
|
||||||
|
type Effects interface {
|
||||||
|
// ModelState returns the current state of a model's process. ok is false
|
||||||
|
// when the model is not handled by this router.
|
||||||
|
ModelState(modelID string) (process.ProcessState, bool)
|
||||||
|
// RunningModels returns the state of every process that is not stopped or
|
||||||
|
// shut down, keyed by model ID. The scheduler uses it to build the running
|
||||||
|
// set it hands the Swapper.
|
||||||
|
RunningModels() map[string]process.ProcessState
|
||||||
|
// StartSwap launches the swap goroutine for modelID, stopping evict first.
|
||||||
|
StartSwap(modelID string, evict []string)
|
||||||
|
// GrantError responds to a caller with an error.
|
||||||
|
GrantError(req HandlerReq, err error)
|
||||||
|
// GrantServe hands a caller the wrapped handler for modelID and reports
|
||||||
|
// whether the caller was still there to receive it. The scheduler bumps
|
||||||
|
// its in-flight count only when this returns true.
|
||||||
|
GrantServe(req HandlerReq, modelID string) bool
|
||||||
|
// StopProcesses stops the named processes in parallel and blocks until all
|
||||||
|
// have stopped. Unknown IDs are skipped.
|
||||||
|
StopProcesses(timeout time.Duration, ids []string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
|
||||||
|
// from conf and bound to the given planner and effects. Currently only "fifo"
|
||||||
|
// (the default) is supported.
|
||||||
|
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||||
|
use := conf.Routing.Scheduler.Use
|
||||||
|
if use == "" {
|
||||||
|
use = "fifo"
|
||||||
|
}
|
||||||
|
switch use {
|
||||||
|
case "fifo":
|
||||||
|
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerReq is one in-flight ServeHTTP request waiting for a routing decision.
|
||||||
|
type HandlerReq struct {
|
||||||
|
Model string
|
||||||
|
Ctx context.Context
|
||||||
|
Respond chan HandlerResp
|
||||||
|
PositionCh chan int
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerResp is what the caller behind a HandlerReq gets back as the routing
// decision: a handler to serve the request with, or an error.
type HandlerResp struct {
	// HandleFunc is the handler to serve with when the request is granted.
	HandleFunc http.HandlerFunc
	// Err is the error returned to the caller in place of a handler.
	Err error
}
|
||||||
|
|
||||||
|
// SwapDone is the event a swap goroutine reports once its target model is
// ready, or once the swap has failed.
type SwapDone struct {
	// ModelID identifies the swap's target model.
	ModelID string
	// Err carries the failure, if the swap failed.
	Err error
}
|
||||||
|
|
||||||
|
// ServeDoneEvent is the event reported when a tracked ServeHTTP handler has
// returned.
type ServeDoneEvent struct {
	// ModelID identifies the model the finished handler was serving.
	ModelID string
}
|
||||||
+149
-43
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -9,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,13 +19,118 @@ const apiUnloadTimeout = 10 * time.Second
|
|||||||
|
|
||||||
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
||||||
type modelRecord struct {
|
type modelRecord struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Meta map[string]any `json:"meta,omitempty"`
|
Architecture map[string]any `json:"architecture,omitempty"`
|
||||||
|
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||||
|
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||||
|
ContextLength int `json:"context_length,omitempty"`
|
||||||
|
Meta map[string]any `json:"meta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// cappedMetadataKeys is the set of top-level /v1/models fields that the
// capabilities renderer produces. When a model's metadata block defines one of
// these keys, the renderer's value takes precedence and the metadata key is
// dropped.
var cappedMetadataKeys = map[string]struct{}{
	"architecture":         {},
	"capabilities":         {},
	"context_length":       {},
	"supported_parameters": {},
}
|
||||||
|
|
||||||
|
// renderCapabilities converts a model's capabilities config into additional
|
||||||
|
// /v1/models fields. Returns zero values when caps.Empty() is true.
|
||||||
|
func renderCapabilities(caps config.ModelCapConfig) (arch map[string]any, capsMap map[string]any, params []string, ctxLen int) {
|
||||||
|
if caps.Empty() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hasIn := len(caps.In) > 0
|
||||||
|
hasOut := len(caps.Out) > 0
|
||||||
|
|
||||||
|
if hasIn || hasOut {
|
||||||
|
arch = make(map[string]any)
|
||||||
|
}
|
||||||
|
if hasIn {
|
||||||
|
arch["input_modalities"] = caps.In
|
||||||
|
}
|
||||||
|
if hasOut {
|
||||||
|
arch["output_modalities"] = caps.Out
|
||||||
|
}
|
||||||
|
if hasIn && hasOut {
|
||||||
|
arch["modality"] = strings.Join(caps.In, "+") + "->" + strings.Join(caps.Out, "+")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build capabilities map only if there's something to put in it.
|
||||||
|
if hasIn || hasOut || caps.Tools || caps.Reranker {
|
||||||
|
capsMap = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasIn {
|
||||||
|
if contains(caps.In, "image") {
|
||||||
|
capsMap["vision"] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasIn && hasOut {
|
||||||
|
if contains(caps.In, "audio") && contains(caps.Out, "text") {
|
||||||
|
capsMap["audio_transcriptions"] = true
|
||||||
|
}
|
||||||
|
if contains(caps.In, "text") && contains(caps.Out, "audio") {
|
||||||
|
capsMap["audio_speech"] = true
|
||||||
|
}
|
||||||
|
if contains(caps.In, "text") && contains(caps.Out, "image") {
|
||||||
|
capsMap["image_generation"] = true
|
||||||
|
}
|
||||||
|
if contains(caps.In, "image") && contains(caps.Out, "image") {
|
||||||
|
capsMap["image_to_image"] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if caps.Tools {
|
||||||
|
capsMap["function_calling"] = true
|
||||||
|
params = []string{"tools", "tool_choice"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if caps.Reranker {
|
||||||
|
capsMap["reranker"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if caps.Context > 0 {
|
||||||
|
ctxLen = caps.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// contains reports whether s occurs as an element of ss. A nil or empty slice
// never contains anything.
func contains(ss []string, s string) bool {
	for i := 0; i < len(ss); i++ {
		if ss[i] == s {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// filterCappedMetadata returns metadata with renderer-owned keys removed.
|
||||||
|
func filterCappedMetadata(md map[string]any) map[string]any {
|
||||||
|
if len(md) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
filtered := make(map[string]any, len(md))
|
||||||
|
for k, v := range md {
|
||||||
|
if _, capped := cappedMetadataKeys[k]; !capped {
|
||||||
|
filtered[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleListModels serves the OpenAI-compatible model listing: local models
|
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||||
@@ -33,7 +139,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
|||||||
created := time.Now().Unix()
|
created := time.Now().Unix()
|
||||||
data := make([]modelRecord, 0, len(s.cfg.Models))
|
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||||
|
|
||||||
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
|
newRecord := func(id, name, description string, metadata map[string]any, caps config.ModelCapConfig) modelRecord {
|
||||||
rec := modelRecord{
|
rec := modelRecord{
|
||||||
ID: id,
|
ID: id,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -42,6 +148,10 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
|||||||
Name: strings.TrimSpace(name),
|
Name: strings.TrimSpace(name),
|
||||||
Description: strings.TrimSpace(description),
|
Description: strings.TrimSpace(description),
|
||||||
}
|
}
|
||||||
|
rec.Architecture, rec.Capabilities, rec.SupportedParameters, rec.ContextLength = renderCapabilities(caps)
|
||||||
|
if !caps.Empty() {
|
||||||
|
metadata = filterCappedMetadata(metadata)
|
||||||
|
}
|
||||||
if len(metadata) > 0 {
|
if len(metadata) > 0 {
|
||||||
rec.Meta = map[string]any{"llamaswap": metadata}
|
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||||
}
|
}
|
||||||
@@ -52,12 +162,12 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
|||||||
if mc.Unlisted {
|
if mc.Unlisted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
|
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||||
|
|
||||||
if s.cfg.IncludeAliasesInList {
|
if s.cfg.IncludeAliasesInList {
|
||||||
for _, alias := range mc.Aliases {
|
for _, alias := range mc.Aliases {
|
||||||
if alias := strings.TrimSpace(alias); alias != "" {
|
if alias := strings.TrimSpace(alias); alias != "" {
|
||||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
|
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -65,7 +175,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
for peerID, peer := range s.cfg.Peers {
|
for peerID, peer := range s.cfg.Peers {
|
||||||
for _, modelID := range peer.Models {
|
for _, modelID := range peer.Models {
|
||||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
|
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}, config.ModelCapConfig{}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +273,7 @@ func (s *Server) startPreload() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||||
|
|
||||||
dw := &discardResponseWriter{status: http.StatusOK}
|
dw := &discardResponseWriter{status: http.StatusOK}
|
||||||
s.local.ServeHTTP(dw, req)
|
s.local.ServeHTTP(dw, req)
|
||||||
@@ -206,9 +316,9 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||||
upstreamPath := r.PathValue("upstreamPath")
|
upstreamPath := r.PathValue("upstreamPath")
|
||||||
|
|
||||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
|
||||||
if !found {
|
if !found {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +340,29 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Strip the /upstream/<model> prefix before forwarding.
|
// Strip the /upstream/<model> prefix before forwarding.
|
||||||
r.URL.Path = remainingPath
|
r.URL.Path = remainingPath
|
||||||
// Pin the resolved model so the router skips body/query extraction.
|
// Pin the resolved model so the router skips body/query extraction.
|
||||||
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||||
|
|
||||||
|
// If the path matches an upstream.ignorePaths entry and the model is
|
||||||
|
// not already loaded, refuse the request without triggering a swap. The
|
||||||
|
// server was not able to process the response because the model was not
|
||||||
|
// already loaded.
|
||||||
|
for _, re := range s.cfg.Upstream.IgnorePaths {
|
||||||
|
if !re.MatchString(remainingPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.local.Handles(modelID) {
|
||||||
|
state, ok := s.local.RunningModels()[modelID]
|
||||||
|
if !ok || state != process.StateReady {
|
||||||
|
shared.SendResponse(w, r, http.StatusConflict,
|
||||||
|
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Either the model is already loaded (no swap would be triggered)
|
||||||
|
// or this is a peer model (peer proxying never swaps). Fall through
|
||||||
|
// to normal dispatch.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case s.local.Handles(modelID):
|
case s.local.Handles(modelID):
|
||||||
@@ -238,32 +370,6 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
case s.peer.Handles(modelID):
|
case s.peer.Handles(modelID):
|
||||||
s.peer.ServeHTTP(w, r)
|
s.peer.ServeHTTP(w, r)
|
||||||
default:
|
default:
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// findModelInPath walks a slash-separated path, building up segments until one
|
|
||||||
// matches a configured model. This resolves model names that contain slashes
|
|
||||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
|
||||||
// remaining path, and whether a match was found.
|
|
||||||
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
|
||||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
|
||||||
name := ""
|
|
||||||
|
|
||||||
for i, part := range parts {
|
|
||||||
if part == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
name = part
|
|
||||||
} else {
|
|
||||||
name = name + "/" + part
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelID, ok := cfg.RealModelName(name); ok {
|
|
||||||
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", "", "", false
|
|
||||||
}
|
|
||||||
|
|||||||
+428
-2
@@ -2,11 +2,17 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_HandleListModels(t *testing.T) {
|
func TestServer_HandleListModels(t *testing.T) {
|
||||||
@@ -78,6 +84,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) {
|
|||||||
|
|
||||||
func TestServer_FindModelInPath(t *testing.T) {
|
func TestServer_FindModelInPath(t *testing.T) {
|
||||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||||
|
"author": {},
|
||||||
"author/model": {},
|
"author/model": {},
|
||||||
"simple": {},
|
"simple": {},
|
||||||
}}
|
}}
|
||||||
@@ -91,13 +98,14 @@ func TestServer_FindModelInPath(t *testing.T) {
|
|||||||
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||||
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||||
{"/author/model", "author/model", "/", true},
|
{"/author/model", "author/model", "/", true},
|
||||||
|
{"/author/v1/chat", "author", "/v1/chat", true},
|
||||||
{"/missing/v1", "", "", false},
|
{"/missing/v1", "", "", false},
|
||||||
{"/", "", "", false},
|
{"/", "", "", false},
|
||||||
}
|
}
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
name, _, rem, found := findModelInPath(cfg, c.path)
|
name, _, rem, found := shared.FindModelInPath(cfg, c.path)
|
||||||
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||||
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||||
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -133,6 +141,165 @@ func TestServer_HandleUpstream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func upstreamMetricsServer(response string) *Server {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
proxylog := logmon.NewWriter(io.Discard)
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
muxlog: logmon.NewWriter(io.Discard),
|
||||||
|
proxylog: proxylog,
|
||||||
|
upstreamlog: logmon.NewWriter(io.Discard),
|
||||||
|
inflight: &inflightCounter{},
|
||||||
|
metrics: newMetricsMonitor(proxylog, 10, 0),
|
||||||
|
local: newStubRouter([]string{"m1"}, response),
|
||||||
|
peer: newStubRouter(nil, ""),
|
||||||
|
}
|
||||||
|
s.routes()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_IgnorePaths(t *testing.T) {
|
||||||
|
// Compile a pattern that matches static asset suffixes.
|
||||||
|
pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`)
|
||||||
|
|
||||||
|
t.Run("matched path, model not loaded, returns 409", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
// running is nil/empty: model is not in RunningModels() => not loaded.
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusConflict {
|
||||||
|
t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "not loaded") {
|
||||||
|
t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("matched path, model already loaded, serves normally", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("matched path, peer model, serves normally", func(t *testing.T) {
|
||||||
|
// Peer routers do not appear via RunningModels on the local router;
|
||||||
|
// they should fall through to normal dispatch without 409.
|
||||||
|
local := newStubRouter(nil, "")
|
||||||
|
peer := newStubRouter([]string{"m1"}, "peer-body")
|
||||||
|
s := newTestServer(local, peer)
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "peer-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
||||||
|
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
||||||
|
s := upstreamMetricsServer(resp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != resp {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
entries := s.metrics.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 metrics entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Model != "m1" {
|
||||||
|
t.Errorf("model = %q, want m1", entries[0].Model)
|
||||||
|
}
|
||||||
|
if entries[0].ReqPath != "/v1/chat/completions" {
|
||||||
|
t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath)
|
||||||
|
}
|
||||||
|
if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 {
|
||||||
|
t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) {
|
||||||
|
s := upstreamMetricsServer("ok")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "ok" {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if len(s.metrics.getMetrics()) != 0 {
|
||||||
|
t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) {
|
||||||
|
s := upstreamMetricsServer(`{"usage":{}}`)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d", w.Code)
|
||||||
|
}
|
||||||
|
if len(s.metrics.getMetrics()) != 0 {
|
||||||
|
t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
@@ -157,3 +324,262 @@ func TestServer_Redirects(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleListModels_Capabilities(t *testing.T) {
|
||||||
|
newServer := func(mc config.ModelConfig) *Server {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m": mc}}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
getModel := func(t *testing.T, s *Server) modelRecord {
|
||||||
|
t.Helper()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
|
||||||
|
var resp struct {
|
||||||
|
Data []modelRecord `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Data) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(resp.Data))
|
||||||
|
}
|
||||||
|
return resp.Data[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("all_fields", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{
|
||||||
|
In: []string{"text", "image"},
|
||||||
|
Out: []string{"text", "audio"},
|
||||||
|
Tools: true,
|
||||||
|
Context: 100000,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
if m.Architecture == nil {
|
||||||
|
t.Fatal("architecture is nil")
|
||||||
|
}
|
||||||
|
if !anySliceStrEqual(m.Architecture["input_modalities"], []string{"text", "image"}) {
|
||||||
|
t.Errorf("input_modalities = %v", m.Architecture["input_modalities"])
|
||||||
|
}
|
||||||
|
if !anySliceStrEqual(m.Architecture["output_modalities"], []string{"text", "audio"}) {
|
||||||
|
t.Errorf("output_modalities = %v", m.Architecture["output_modalities"])
|
||||||
|
}
|
||||||
|
if m.Architecture["modality"] != "text+image->text+audio" {
|
||||||
|
t.Errorf("modality = %v", m.Architecture["modality"])
|
||||||
|
}
|
||||||
|
if m.Capabilities == nil || m.Capabilities["vision"] != true {
|
||||||
|
t.Errorf("vision = %v", m.Capabilities)
|
||||||
|
}
|
||||||
|
if m.Capabilities["audio_speech"] != true {
|
||||||
|
t.Errorf("audio_speech = %v", m.Capabilities["audio_speech"])
|
||||||
|
}
|
||||||
|
if m.Capabilities["function_calling"] != true {
|
||||||
|
t.Errorf("function_calling = %v", m.Capabilities["function_calling"])
|
||||||
|
}
|
||||||
|
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
|
||||||
|
t.Errorf("supported_parameters = %v", m.SupportedParameters)
|
||||||
|
}
|
||||||
|
if m.ContextLength != 100000 {
|
||||||
|
t.Errorf("context_length = %d", m.ContextLength)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("in_only", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{In: []string{"text", "image"}},
|
||||||
|
}))
|
||||||
|
if m.Architecture == nil {
|
||||||
|
t.Fatal("architecture is nil")
|
||||||
|
}
|
||||||
|
if _, ok := m.Architecture["output_modalities"]; ok {
|
||||||
|
t.Error("should not have output_modalities")
|
||||||
|
}
|
||||||
|
if _, ok := m.Architecture["modality"]; ok {
|
||||||
|
t.Error("should not have modality")
|
||||||
|
}
|
||||||
|
if m.Capabilities == nil || m.Capabilities["vision"] != true {
|
||||||
|
t.Error("expected vision: true")
|
||||||
|
}
|
||||||
|
if m.SupportedParameters != nil {
|
||||||
|
t.Error("should not have supported_parameters")
|
||||||
|
}
|
||||||
|
if m.ContextLength != 0 {
|
||||||
|
t.Error("should not have context_length")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("out_only", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{Out: []string{"audio"}},
|
||||||
|
}))
|
||||||
|
if m.Architecture == nil {
|
||||||
|
t.Fatal("architecture is nil")
|
||||||
|
}
|
||||||
|
if _, ok := m.Architecture["input_modalities"]; ok {
|
||||||
|
t.Error("should not have input_modalities")
|
||||||
|
}
|
||||||
|
if len(m.Capabilities) > 0 {
|
||||||
|
t.Errorf("expected no capabilities, got %v", m.Capabilities)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tools", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{Tools: true},
|
||||||
|
}))
|
||||||
|
if m.Capabilities == nil || m.Capabilities["function_calling"] != true {
|
||||||
|
t.Error("expected function_calling: true")
|
||||||
|
}
|
||||||
|
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
|
||||||
|
t.Errorf("supported_parameters = %v", m.SupportedParameters)
|
||||||
|
}
|
||||||
|
if m.Architecture != nil {
|
||||||
|
t.Error("should not have architecture")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reranker", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{Reranker: true},
|
||||||
|
}))
|
||||||
|
if m.Capabilities == nil || m.Capabilities["reranker"] != true {
|
||||||
|
t.Error("expected reranker: true")
|
||||||
|
}
|
||||||
|
if m.Architecture != nil {
|
||||||
|
t.Error("should not have architecture")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{Context: 32768},
|
||||||
|
}))
|
||||||
|
if m.ContextLength != 32768 {
|
||||||
|
t.Errorf("context_length = %d", m.ContextLength)
|
||||||
|
}
|
||||||
|
if m.Architecture != nil {
|
||||||
|
t.Error("should not have architecture")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("audio_transcriptions", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{In: []string{"audio"}, Out: []string{"text"}},
|
||||||
|
}))
|
||||||
|
if m.Capabilities == nil || m.Capabilities["audio_transcriptions"] != true {
|
||||||
|
t.Error("expected audio_transcriptions: true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("image_generation", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{In: []string{"text"}, Out: []string{"image"}},
|
||||||
|
}))
|
||||||
|
if m.Capabilities == nil || m.Capabilities["image_generation"] != true {
|
||||||
|
t.Error("expected image_generation: true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("image_to_image", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{In: []string{"image"}, Out: []string{"image"}},
|
||||||
|
}))
|
||||||
|
if m.Capabilities == nil || m.Capabilities["image_to_image"] != true {
|
||||||
|
t.Error("expected image_to_image: true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty_skip", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{}))
|
||||||
|
if m.Architecture != nil {
|
||||||
|
t.Error("should not have architecture")
|
||||||
|
}
|
||||||
|
if m.Capabilities != nil {
|
||||||
|
t.Error("should not have capabilities")
|
||||||
|
}
|
||||||
|
if m.SupportedParameters != nil {
|
||||||
|
t.Error("should not have supported_parameters")
|
||||||
|
}
|
||||||
|
if m.ContextLength != 0 {
|
||||||
|
t.Error("should not have context_length")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("metadata_precedence", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Capabilities: config.ModelCapConfig{In: []string{"text"}},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"architecture": "should-be-dropped",
|
||||||
|
"custom_field": "should-remain",
|
||||||
|
"capabilities": "also-dropped",
|
||||||
|
"other_metadata": "also-remain",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
if m.Architecture == nil || m.Architecture["input_modalities"] == nil {
|
||||||
|
t.Fatal("architecture should be rendered, not from metadata")
|
||||||
|
}
|
||||||
|
if m.Meta == nil || m.Meta["llamaswap"] == nil {
|
||||||
|
t.Fatal("meta.llamaswap should exist")
|
||||||
|
}
|
||||||
|
meta := m.Meta["llamaswap"].(map[string]any)
|
||||||
|
if _, ok := meta["architecture"]; ok {
|
||||||
|
t.Error("architecture should be filtered from metadata")
|
||||||
|
}
|
||||||
|
if _, ok := meta["custom_field"]; !ok {
|
||||||
|
t.Error("custom_field should remain in metadata")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("metadata_passthrough_no_caps", func(t *testing.T) {
|
||||||
|
m := getModel(t, newServer(config.ModelConfig{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"architecture": "preserved",
|
||||||
|
"context_length": 4096,
|
||||||
|
"capabilities": "preserved",
|
||||||
|
"custom_field": "preserved",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
if m.Architecture != nil {
|
||||||
|
t.Error("should not have architecture when caps is empty")
|
||||||
|
}
|
||||||
|
if m.Meta == nil || m.Meta["llamaswap"] == nil {
|
||||||
|
t.Fatal("meta.llamaswap should exist")
|
||||||
|
}
|
||||||
|
meta := m.Meta["llamaswap"].(map[string]any)
|
||||||
|
if _, ok := meta["architecture"]; !ok {
|
||||||
|
t.Error("architecture should be preserved in metadata when caps is empty")
|
||||||
|
}
|
||||||
|
if _, ok := meta["context_length"]; !ok {
|
||||||
|
t.Error("context_length should be preserved in metadata when caps is empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSliceEqual(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func anySliceStrEqual(v any, want []string) bool {
|
||||||
|
arr, ok := v.([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(arr) != len(want) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range arr {
|
||||||
|
if s, ok := arr[i].(string); !ok || s != want[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
+28
-23
@@ -12,19 +12,19 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// apiModel is one entry in the /api/events modelStatus payload.
|
// apiModel is one entry in the /api/events modelStatus payload.
|
||||||
type apiModel struct {
|
type apiModel struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
State string `json:"state"`
|
State string `json:"state"`
|
||||||
Unlisted bool `json:"unlisted"`
|
Unlisted bool `json:"unlisted"`
|
||||||
PeerID string `json:"peerID"`
|
PeerID string `json:"peerID"`
|
||||||
Aliases []string `json:"aliases,omitempty"`
|
Aliases []string `json:"aliases,omitempty"`
|
||||||
|
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// modelStatus returns every configured model joined with its current process
|
// modelStatus returns every configured model joined with its current process
|
||||||
@@ -45,13 +45,15 @@ func (s *Server) modelStatus() []apiModel {
|
|||||||
if st, ok := running[id]; ok {
|
if st, ok := running[id]; ok {
|
||||||
state = string(st)
|
state = string(st)
|
||||||
}
|
}
|
||||||
|
_, capsMap, _, _ := renderCapabilities(mc.Capabilities)
|
||||||
models = append(models, apiModel{
|
models = append(models, apiModel{
|
||||||
Id: id,
|
Id: id,
|
||||||
Name: mc.Name,
|
Name: mc.Name,
|
||||||
Description: mc.Description,
|
Description: mc.Description,
|
||||||
State: state,
|
State: state,
|
||||||
Unlisted: mc.Unlisted,
|
Unlisted: mc.Unlisted,
|
||||||
Aliases: mc.Aliases,
|
Aliases: mc.Aliases,
|
||||||
|
Capabilities: capsMap,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,11 +78,11 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
|||||||
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||||
realName, found := s.cfg.RealModelName(requested)
|
realName, found := s.cfg.RealModelName(requested)
|
||||||
if !found {
|
if !found {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !s.local.Handles(realName) {
|
if !s.local.Handles(realName) {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
shared.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.local.Unload(apiUnloadTimeout, realName)
|
s.local.Unload(apiUnloadTimeout, realName)
|
||||||
@@ -92,7 +94,7 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
data, err := s.metrics.getMetricsJSON()
|
data, err := s.metrics.getMetricsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -103,7 +105,9 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
|||||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||||
if s.perf == nil {
|
if s.perf == nil {
|
||||||
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
json.NewEncoder(w).Encode(map[string]bool{"enabled": false})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +116,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
|||||||
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||||
after, err := time.Parse(time.RFC3339, afterStr)
|
after, err := time.Parse(time.RFC3339, afterStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
shared.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||||
@@ -134,6 +138,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"enabled": true,
|
||||||
"sys_stats": sysStats,
|
"sys_stats": sysStats,
|
||||||
"gpu_stats": gpuStats,
|
"gpu_stats": gpuStats,
|
||||||
})
|
})
|
||||||
@@ -153,19 +158,19 @@ func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||||
id, err := strconv.Atoi(r.PathValue("id"))
|
id, err := strconv.Atoi(r.PathValue("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
shared.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
capture := s.metrics.getCaptureByID(id)
|
capture := s.metrics.getCaptureByID(id)
|
||||||
if capture == nil {
|
if capture == nil {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(capture)
|
jsonBytes, err := json.Marshal(capture)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -198,7 +203,7 @@ func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+17
-31
@@ -1,19 +1,17 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateAuthMiddleware returns middleware that validates API keys when the
|
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||||
// config declares any. It accepts the key via Authorization: Bearer,
|
// config declares any. It accepts the key via Authorization: Bearer,
|
||||||
// Authorization: Basic (password field), or x-api-key. On success the auth
|
// Authorization: Basic (password field), or x-api-key. When no keys are
|
||||||
// headers are stripped so they never leak to upstream. When no keys are
|
|
||||||
// configured the middleware is a pass-through.
|
// configured the middleware is a pass-through.
|
||||||
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||||
keys := cfg.RequiredAPIKeys
|
keys := cfg.RequiredAPIKeys
|
||||||
@@ -22,7 +20,7 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
provided := extractAPIKey(r)
|
provided := shared.ExtractAPIKey(r)
|
||||||
|
|
||||||
valid := false
|
valid := false
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
@@ -33,41 +31,29 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
}
|
}
|
||||||
if !valid {
|
if !valid {
|
||||||
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||||
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
shared.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Header.Del("Authorization")
|
|
||||||
r.Header.Del("x-api-key")
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
// CreateRequestContextMiddleware returns middleware that extracts model and
|
||||||
// then Bearer, then x-api-key.
|
// auth info from the request into the context. Requests where no model can be
|
||||||
func extractAPIKey(r *http.Request) string {
|
// identified are rejected with a 404.
|
||||||
var bearerKey, basicKey string
|
func CreateRequestContextMiddleware(cfg config.Config) chain.Middleware {
|
||||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
return func(next http.Handler) http.Handler {
|
||||||
if strings.HasPrefix(auth, "Bearer ") {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
data, err := shared.FetchContext(r, cfg)
|
||||||
} else if strings.HasPrefix(auth, "Basic ") {
|
if err != nil {
|
||||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
return
|
||||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
|
||||||
basicKey = parts[1] // password field is the API key
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
_ = data
|
||||||
}
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
switch {
|
|
||||||
case basicKey != "":
|
|
||||||
return basicKey
|
|
||||||
case bearerKey != "":
|
|
||||||
return bearerKey
|
|
||||||
default:
|
|
||||||
return r.Header.Get("x-api-key")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,48 +1,14 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
|
||||||
basicHeader := func(user, pass string) string {
|
|
||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
|
||||||
}
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
auth string
|
|
||||||
xapi string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"none", "", "", ""},
|
|
||||||
{"bearer", "Bearer tok123", "", "tok123"},
|
|
||||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
|
||||||
{"x-api-key", "", "xkey", "xkey"},
|
|
||||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
|
||||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
|
||||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
|
||||||
}
|
|
||||||
for _, c := range cases {
|
|
||||||
t.Run(c.name, func(t *testing.T) {
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if c.auth != "" {
|
|
||||||
r.Header.Set("Authorization", c.auth)
|
|
||||||
}
|
|
||||||
if c.xapi != "" {
|
|
||||||
r.Header.Set("x-api-key", c.xapi)
|
|
||||||
}
|
|
||||||
if got := extractAPIKey(r); got != c.want {
|
|
||||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
in string
|
in string
|
||||||
@@ -74,11 +40,42 @@ func TestServer_IsTokenChar(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_RequestContextMiddleware(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"llama3": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := CreateRequestContextMiddleware(cfg)
|
||||||
|
|
||||||
|
t.Run("known model passes through", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"llama3"}`))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model returns 404", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_AuthMiddleware(t *testing.T) {
|
func TestServer_AuthMiddleware(t *testing.T) {
|
||||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
|
||||||
t.Error("auth headers leaked to upstream")
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"golang.org/x/sync/semaphore"
|
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
|
||||||
// the model config leaves concurrencyLimit unset. Matches the legacy
|
|
||||||
// proxy.Process default.
|
|
||||||
const defaultConcurrencyLimit = 10
|
|
||||||
|
|
||||||
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
|
|
||||||
// model-dispatched requests per model. Each model gets a semaphore sized to
|
|
||||||
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
|
|
||||||
// immediately acquire a slot is rejected with 429. Models without a local
|
|
||||||
// config entry (e.g. peer-routed models) are not limited.
|
|
||||||
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
|
||||||
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
|
|
||||||
for id, mc := range cfg.Models {
|
|
||||||
limit := defaultConcurrencyLimit
|
|
||||||
if mc.ConcurrencyLimit > 0 {
|
|
||||||
limit = mc.ConcurrencyLimit
|
|
||||||
}
|
|
||||||
semaphores[id] = semaphore.NewWeighted(int64(limit))
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
data, err := router.FetchContext(r, cfg)
|
|
||||||
if err != nil {
|
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// fall through for peer models
|
|
||||||
sem, ok := semaphores[data.ModelID]
|
|
||||||
if !ok {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !sem.TryAcquire(1) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
w.Write([]byte(`{"error":"Too many requests"}`))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer sem.Release(1)
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
|
||||||
)
|
|
||||||
|
|
||||||
func concurrencyTestReq(model string) *http.Request {
|
|
||||||
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
|
||||||
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
|
||||||
cfg := config.Config{
|
|
||||||
Models: map[string]config.ModelConfig{
|
|
||||||
"m1": {ConcurrencyLimit: 1},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
entered := make(chan struct{})
|
|
||||||
release := make(chan struct{})
|
|
||||||
var once sync.Once
|
|
||||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
once.Do(func() { close(entered) })
|
|
||||||
<-release
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
|
||||||
|
|
||||||
// First request occupies the only slot.
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer close(done)
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
|
|
||||||
}()
|
|
||||||
<-entered
|
|
||||||
|
|
||||||
// Second concurrent request is rejected with 429.
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
|
||||||
if w.Code != http.StatusTooManyRequests {
|
|
||||||
t.Fatalf("over-limit status = %d, want 429", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once the slot frees, a new request succeeds.
|
|
||||||
close(release)
|
|
||||||
<-done
|
|
||||||
w = httptest.NewRecorder()
|
|
||||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("post-release status = %d, want 200", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
|
|
||||||
cfg := config.Config{Models: map[string]config.ModelConfig{}}
|
|
||||||
|
|
||||||
called := 0
|
|
||||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
called++
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
|
|
||||||
if w.Code != http.StatusOK || called != 1 {
|
|
||||||
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,9 +34,9 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := router.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,13 +48,13 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
shared.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err = applyFilters(body, data.Model, useModelName, filters)
|
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,9 +84,9 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := router.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,13 +97,13 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
shared.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||||
@@ -76,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
|||||||
case "upstream":
|
case "upstream":
|
||||||
return s.upstreamlog, nil
|
return s.upstreamlog, nil
|
||||||
default:
|
default:
|
||||||
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||||
if log, ok := s.local.ProcessLogger(modelID); ok {
|
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
@@ -102,13 +102,13 @@ func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
logger, err := s.getLogger(logMonitorID)
|
logger, err := s.getLogger(logMonitorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
shared.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+132
-33
@@ -25,6 +25,8 @@ import (
|
|||||||
// TokenMetrics holds token usage and performance metrics.
|
// TokenMetrics holds token usage and performance metrics.
|
||||||
type TokenMetrics struct {
|
type TokenMetrics struct {
|
||||||
CachedTokens int `json:"cache_tokens"`
|
CachedTokens int `json:"cache_tokens"`
|
||||||
|
DraftTokens int `json:"draft_tokens"`
|
||||||
|
DraftAccTokens int `json:"draft_acc_tokens"`
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
@@ -33,15 +35,17 @@ type TokenMetrics struct {
|
|||||||
|
|
||||||
// ActivityLogEntry represents parsed token statistics from llama-server logs.
|
// ActivityLogEntry represents parsed token statistics from llama-server logs.
|
||||||
type ActivityLogEntry struct {
|
type ActivityLogEntry struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
ReqPath string `json:"req_path"`
|
ReqPath string `json:"req_path"`
|
||||||
RespContentType string `json:"resp_content_type"`
|
RespContentType string `json:"resp_content_type"`
|
||||||
RespStatusCode int `json:"resp_status_code"`
|
RespStatusCode int `json:"resp_status_code"`
|
||||||
Tokens TokenMetrics `json:"tokens"`
|
Tokens TokenMetrics `json:"tokens"`
|
||||||
DurationMs int `json:"duration_ms"`
|
DurationMs int `json:"duration_ms"`
|
||||||
HasCapture bool `json:"has_capture"`
|
HasCapture bool `json:"has_capture"`
|
||||||
|
ErrorMsg string `json:"error_msg,omitempty"`
|
||||||
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ActivityLogEvent carries a single activity log entry to event subscribers.
|
// ActivityLogEvent carries a single activity log entry to event subscribers.
|
||||||
@@ -122,9 +126,11 @@ func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// record parses a completed response body and stores/emits an activity entry.
|
// record parses a completed response body and stores/emits an activity entry.
|
||||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
// Successful requests store a zstd+CBOR capture (when enabled) with cf
|
||||||
// requests, with cf controlling which request/response parts are retained.
|
// controlling which parts are retained. Failed (non-200) requests capture the
|
||||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
// request only and set ErrorMsg to a description of the failure, so the error
|
||||||
|
// can be inspected without storing unreadable raw response bytes. reqBody and
|
||||||
|
// reqHeaders are the request data buffered before dispatch.
|
||||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||||
tm := ActivityLogEntry{
|
tm := ActivityLogEntry{
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
@@ -135,6 +141,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctxData, ok := shared.ReadContext(r.Context()); ok && len(ctxData.Metadata) > 0 {
|
||||||
|
tm.Metadata = make(map[string]string, len(ctxData.Metadata))
|
||||||
|
for k, v := range ctxData.Metadata {
|
||||||
|
tm.Metadata[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
queueAndEmit := func() {
|
queueAndEmit := func() {
|
||||||
tm.ID = mp.queueMetrics(tm)
|
tm.ID = mp.queueMetrics(tm)
|
||||||
mp.emitMetric(tm)
|
mp.emitMetric(tm)
|
||||||
@@ -142,7 +155,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
|
|
||||||
if recorder.Status() != http.StatusOK {
|
if recorder.Status() != http.StatusOK {
|
||||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||||
queueAndEmit()
|
decoded, decErr := mp.decodeResponseBody(recorder, r.URL.Path)
|
||||||
|
tm.ErrorMsg = failedErrorMessage(recorder.Status(), decoded, decErr)
|
||||||
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
// Capture the request only; the failure is surfaced via ErrorMsg
|
||||||
|
// rather than storing the (possibly undisplayable) response body.
|
||||||
|
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf&^captureRespBody, reqBody, reqHeaders, nil)
|
||||||
|
mp.emitMetric(tm)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,6 +176,7 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
decoded, err := decompressBody(body, encoding)
|
decoded, err := decompressBody(body, encoding)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||||
|
tm.ErrorMsg = fmt.Sprintf("response decompression failed: %v", err)
|
||||||
queueAndEmit()
|
queueAndEmit()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -195,28 +215,99 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
}
|
}
|
||||||
|
|
||||||
tm.ID = mp.queueMetrics(tm)
|
tm.ID = mp.queueMetrics(tm)
|
||||||
if mp.enableCaptures {
|
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf, reqBody, reqHeaders, body)
|
||||||
capture := ReqRespCapture{
|
mp.emitMetric(tm)
|
||||||
ID: tm.ID,
|
}
|
||||||
ReqPath: r.URL.Path,
|
|
||||||
ReqHeaders: reqHeaders,
|
// storeCapture assembles a ReqRespCapture for id, honoring the captureFields
|
||||||
}
|
// mask, and stores it when captures are enabled. body is the response body to
|
||||||
if cf&captureReqBody != 0 {
|
// capture (already decompressed by the caller); pass nil to omit it. Returns
|
||||||
capture.ReqBody = reqBody
|
// true if a capture was stored.
|
||||||
}
|
func (mp *metricsMonitor) storeCapture(id int, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string, body []byte) bool {
|
||||||
if cf&captureRespHeaders != 0 {
|
if !mp.enableCaptures {
|
||||||
capture.RespHeaders = headerMap(recorder.Header())
|
return false
|
||||||
redactHeaders(capture.RespHeaders)
|
}
|
||||||
delete(capture.RespHeaders, "Content-Encoding")
|
capture := ReqRespCapture{
|
||||||
}
|
ID: id,
|
||||||
if cf&captureRespBody != 0 {
|
ReqPath: r.URL.Path,
|
||||||
capture.RespBody = body
|
ReqHeaders: reqHeaders,
|
||||||
}
|
}
|
||||||
if mp.addCapture(capture) {
|
if cf&captureReqBody != 0 {
|
||||||
tm.HasCapture = true
|
capture.ReqBody = reqBody
|
||||||
|
}
|
||||||
|
if cf&captureRespHeaders != 0 {
|
||||||
|
capture.RespHeaders = headerMap(recorder.Header())
|
||||||
|
redactHeaders(capture.RespHeaders)
|
||||||
|
delete(capture.RespHeaders, "Content-Encoding")
|
||||||
|
}
|
||||||
|
if cf&captureRespBody != 0 {
|
||||||
|
capture.RespBody = body
|
||||||
|
}
|
||||||
|
return mp.addCapture(capture)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeResponseBody returns the buffered response body, decompressing it when
|
||||||
|
// the upstream set a Content-Encoding we recognize. On decompression failure it
|
||||||
|
// logs a warning and returns an error so the caller can record a description
|
||||||
|
// (via ErrorMsg) instead of storing unreadable raw bytes.
|
||||||
|
func (mp *metricsMonitor) decodeResponseBody(recorder *responseBodyCopier, path string) ([]byte, error) {
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
encoding := recorder.Header().Get("Content-Encoding")
|
||||||
|
if encoding == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
decoded, err := decompressBody(body, encoding)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("metrics: response decompression failed: %v, path=%s", err, path)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorMessagePaths lists JSON paths where a human-readable error message can
|
||||||
|
// live across OpenAI- and llama.cpp-style error responses.
|
||||||
|
var errorMessagePaths = []string{"error.message", "error", "message", "detail"}
|
||||||
|
|
||||||
|
// extractErrorMessage pulls a human-readable error string from a JSON error
|
||||||
|
// response. Returns "" if no message is found or the body is not valid JSON.
|
||||||
|
func extractErrorMessage(body []byte) string {
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed := gjson.ParseBytes(body)
|
||||||
|
for _, path := range errorMessagePaths {
|
||||||
|
v := parsed.Get(path)
|
||||||
|
if v.Exists() && v.Type == gjson.String {
|
||||||
|
if s := strings.TrimSpace(v.String()); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mp.emitMetric(tm)
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// failedErrorMessage builds a human-readable description for a non-200 response.
|
||||||
|
// It prefers an error message parsed from the (decompressed) body and falls back
|
||||||
|
// to the HTTP status text. A non-nil decErr indicates the body could not be
|
||||||
|
// decoded, in which case the decode error is described instead.
|
||||||
|
func failedErrorMessage(status int, body []byte, decErr error) string {
|
||||||
|
const maxLen = 500
|
||||||
|
if decErr != nil {
|
||||||
|
return fmt.Sprintf("response decode failed: %v", decErr)
|
||||||
|
}
|
||||||
|
if msg := extractErrorMessage(body); msg != "" {
|
||||||
|
if len(msg) > maxLen {
|
||||||
|
msg = msg[:maxLen] + "..."
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
if text := http.StatusText(status); text != "" {
|
||||||
|
return fmt.Sprintf("%d %s", status, text)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("HTTP %d", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||||
@@ -337,6 +428,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
durationMs := wallDurationMs
|
durationMs := wallDurationMs
|
||||||
tokensPerSecond := -1.0
|
tokensPerSecond := -1.0
|
||||||
promptPerSecond := -1.0
|
promptPerSecond := -1.0
|
||||||
|
draftTokens := -1
|
||||||
|
draftAccTokens := -1
|
||||||
|
|
||||||
if timings.Exists() {
|
if timings.Exists() {
|
||||||
inputTokens = timings.Get("prompt_n").Int()
|
inputTokens = timings.Get("prompt_n").Int()
|
||||||
@@ -350,6 +443,10 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||||
cachedTokens = cachedValue.Int()
|
cachedTokens = cachedValue.Int()
|
||||||
}
|
}
|
||||||
|
if timings.Get("draft_n").Exists() && timings.Get("draft_n_accepted").Exists() {
|
||||||
|
draftTokens = int(timings.Get("draft_n").Int())
|
||||||
|
draftAccTokens = int(timings.Get("draft_n_accepted").Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ActivityLogEntry{
|
return ActivityLogEntry{
|
||||||
@@ -357,6 +454,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
Model: modelID,
|
Model: modelID,
|
||||||
Tokens: TokenMetrics{
|
Tokens: TokenMetrics{
|
||||||
CachedTokens: int(cachedTokens),
|
CachedTokens: int(cachedTokens),
|
||||||
|
DraftTokens: draftTokens,
|
||||||
|
DraftAccTokens: draftAccTokens,
|
||||||
InputTokens: int(inputTokens),
|
InputTokens: int(inputTokens),
|
||||||
OutputTokens: int(outputTokens),
|
OutputTokens: int(outputTokens),
|
||||||
PromptPerSecond: promptPerSecond,
|
PromptPerSecond: promptPerSecond,
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateMetricsMiddleware returns middleware that records token metrics for
|
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||||
@@ -21,17 +22,36 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine the model-routed endpoint path. Regular routes are
|
||||||
|
// already meterable; /upstream/<model>/<path> is metered only when
|
||||||
|
// the remaining path matches a model-dispatched endpoint.
|
||||||
|
checkPath := r.URL.Path
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||||
|
var found bool
|
||||||
|
_, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||||
|
if !found {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isMetricsRecordPath(checkPath) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve the model now so downstream dispatch hits the context
|
// Resolve the model now so downstream dispatch hits the context
|
||||||
// fast path; FetchContext restores the request body.
|
// fast path; FetchContext restores the request body for regular
|
||||||
data, err := router.FetchContext(r, cfg)
|
// routes and extracts the model from the URL for /upstream routes.
|
||||||
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Buffer the request body/headers for capture before dispatch
|
// Buffer the request body/headers for capture before dispatch
|
||||||
// consumes them.
|
// consumes them.
|
||||||
cf := captureFieldsFor(r.URL.Path)
|
cf := captureFieldsFor(checkPath)
|
||||||
var reqBody []byte
|
var reqBody []byte
|
||||||
var reqHeaders map[string]string
|
var reqHeaders map[string]string
|
||||||
if mm.enableCaptures {
|
if mm.enableCaptures {
|
||||||
|
|||||||
@@ -1,9 +1,16 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,6 +63,199 @@ func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordMetadata(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(nil, 10, 0)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"usage":{}}`))
|
||||||
|
r = r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{
|
||||||
|
ModelID: "m",
|
||||||
|
Metadata: map[string]string{"client": "web", "trace": "abc"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusOK)
|
||||||
|
copier.Write([]byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, 0, nil, nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Metadata["client"] != "web" {
|
||||||
|
t.Errorf("client = %q, want web", entries[0].Metadata["client"])
|
||||||
|
}
|
||||||
|
if entries[0].Metadata["trace"] != "abc" {
|
||||||
|
t.Errorf("trace = %q, want abc", entries[0].Metadata["trace"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestCapture(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
reqHeaders := map[string]string{"content-type": "application/json"}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Header().Set("Content-Type", "application/json")
|
||||||
|
copier.WriteHeader(http.StatusBadGateway)
|
||||||
|
copier.Write([]byte(`{"error":{"message":"model unavailable"}}`))
|
||||||
|
|
||||||
|
reqBody := []byte(`{"model":"m","messages":[]}`)
|
||||||
|
mm.record("m", r, copier, captureAll, reqBody, reqHeaders)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
entry := entries[0]
|
||||||
|
if entry.RespStatusCode != http.StatusBadGateway {
|
||||||
|
t.Errorf("status = %d, want %d", entry.RespStatusCode, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
if entry.ErrorMsg != "model unavailable" {
|
||||||
|
t.Errorf("error_msg = %q, want extracted message", entry.ErrorMsg)
|
||||||
|
}
|
||||||
|
if !entry.HasCapture {
|
||||||
|
t.Fatal("failed request should capture the request so it can be inspected")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := mm.getCaptureByID(entry.ID)
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("capture not found")
|
||||||
|
}
|
||||||
|
if string(got.ReqBody) != `{"model":"m","messages":[]}` {
|
||||||
|
t.Errorf("req body = %q", got.ReqBody)
|
||||||
|
}
|
||||||
|
if len(got.RespBody) != 0 {
|
||||||
|
t.Errorf("resp body stored for failed request (len=%d); want none", len(got.RespBody))
|
||||||
|
}
|
||||||
|
if got.RespHeaders["Content-Type"] != "application/json" {
|
||||||
|
t.Errorf("resp Content-Type = %q", got.RespHeaders["Content-Type"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestStatusFallback(t *testing.T) {
|
||||||
|
// Non-JSON error body: ErrorMsg falls back to the HTTP status text.
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusBadGateway)
|
||||||
|
copier.Write([]byte("<html>upstream down</html>"))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, nil, nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].ErrorMsg != "502 Bad Gateway" {
|
||||||
|
t.Errorf("error_msg = %q, want status text", entries[0].ErrorMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestCaptureDisabled(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 0) // captures disabled
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusInternalServerError)
|
||||||
|
copier.Write([]byte(`{"error":"boom"}`))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].HasCapture {
|
||||||
|
t.Fatal("captures disabled, HasCapture should be false")
|
||||||
|
}
|
||||||
|
// ErrorMsg is independent of whether captures are enabled.
|
||||||
|
if entries[0].ErrorMsg != "boom" {
|
||||||
|
t.Errorf("error_msg = %q, want boom", entries[0].ErrorMsg)
|
||||||
|
}
|
||||||
|
if mm.getCaptureByID(entries[0].ID) != nil {
|
||||||
|
t.Fatal("no capture should be stored when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordDecompressionFailureSetsErrorMsg(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Header().Set("Content-Encoding", "gzip")
|
||||||
|
copier.WriteHeader(http.StatusOK)
|
||||||
|
copier.Write([]byte("not-really-gzip"))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].ErrorMsg == "" {
|
||||||
|
t.Fatal("expected ErrorMsg for decompression failure")
|
||||||
|
}
|
||||||
|
// Raw bytes must not be stored when the body could not be decoded.
|
||||||
|
if entries[0].HasCapture {
|
||||||
|
t.Fatal("decompression failure should not store a capture")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_DecodeResponseBody(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
|
||||||
|
// No Content-Encoding: body returned unchanged.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Write([]byte("plain"))
|
||||||
|
got, err := mm.decodeResponseBody(copier, "/p")
|
||||||
|
if err != nil || string(got) != "plain" {
|
||||||
|
t.Fatalf("plain body = %q, err = %v", got, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bogus gzip payload: returns an error and no body (no raw bytes kept).
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
copier2 := newBodyCopier(w2)
|
||||||
|
copier2.Header().Set("Content-Encoding", "gzip")
|
||||||
|
copier2.Write([]byte("not-really-gzip"))
|
||||||
|
got, err = mm.decodeResponseBody(copier2, "/p")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected decompression error")
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("expected nil body on failure, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ExtractErrorMessage(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"openai object", `{"error":{"message":"rate limited"}}`, "rate limited"},
|
||||||
|
{"string error", `{"error":"bad request"}`, "bad request"},
|
||||||
|
{"message field", `{"message":"nope"}`, "nope"},
|
||||||
|
{"detail field", `{"detail":"oops"}`, "oops"},
|
||||||
|
{"object error ignored", `{"error":{"code":42}}`, ""},
|
||||||
|
{"no error", `{"usage":{}}`, ""},
|
||||||
|
{"invalid json", `not-json`, ""},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := extractErrorMessage([]byte(tc.body)); got != tc.want {
|
||||||
|
t.Errorf("extractErrorMessage = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||||
// /infill responses are arrays; timings live in the last element.
|
// /infill responses are arrays; timings live in the last element.
|
||||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||||
@@ -72,3 +272,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) {
|
|||||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that
|
||||||
|
// an /upstream/<model>/v1/audio/speech request uses the path-specific capture
|
||||||
|
// mask (headers only) rather than falling back to captureAll.
|
||||||
|
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
|
||||||
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "audio/mpeg")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("BINARY-AUDIO-DATA"))
|
||||||
|
})
|
||||||
|
handler := CreateMetricsMiddleware(mm, cfg)(inner)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
|
||||||
|
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) == 0 {
|
||||||
|
t.Fatal("no metrics recorded")
|
||||||
|
}
|
||||||
|
last := entries[len(entries)-1]
|
||||||
|
if !last.HasCapture {
|
||||||
|
t.Fatal("expected capture to be stored")
|
||||||
|
}
|
||||||
|
cap := mm.getCaptureByID(last.ID)
|
||||||
|
if cap == nil {
|
||||||
|
t.Fatal("capture not found")
|
||||||
|
}
|
||||||
|
if len(cap.RespBody) != 0 {
|
||||||
|
t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", len(cap.RespBody))
|
||||||
|
}
|
||||||
|
if len(cap.RespHeaders) == 0 {
|
||||||
|
t.Error("RespHeaders not stored; want captureRespHeaders mask")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+40
-23
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||||
@@ -88,6 +89,27 @@ var modelGetRoutes = []string{
|
|||||||
"/sdapi/v1/loras",
|
"/sdapi/v1/loras",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isMetricsRecordPath reports whether path is one of the model-dispatched
|
||||||
|
// endpoints that the metrics middleware records in the activity log.
|
||||||
|
func isMetricsRecordPath(path string) bool {
|
||||||
|
for _, p := range modelPostJSONRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range modelPostFormRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range modelGetRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// BuildInfo carries version metadata surfaced by GET /api/version.
|
// BuildInfo carries version metadata surfaced by GET /api/version.
|
||||||
type BuildInfo struct {
|
type BuildInfo struct {
|
||||||
Version string
|
Version string
|
||||||
@@ -99,12 +121,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
|||||||
var local router.LocalRouter
|
var local router.LocalRouter
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if cfg.Matrix != nil {
|
switch cfg.Routing.Router.Use {
|
||||||
|
case "matrix":
|
||||||
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating matrix router: %w", err)
|
return nil, fmt.Errorf("creating matrix router: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
default: // "group"
|
||||||
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating group router: %w", err)
|
return nil, fmt.Errorf("creating group router: %w", err)
|
||||||
@@ -137,13 +160,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
|||||||
}
|
}
|
||||||
|
|
||||||
// localPeerHandler dispatches a model-routed request to the local or peer
|
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||||
// router. The model is resolved once via router.FetchContext.
|
// router. The model is resolved once via shared.FetchContext.
|
||||||
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
stripVersionPrefix(r)
|
stripVersionPrefix(r)
|
||||||
|
|
||||||
data, err := router.FetchContext(r, s.cfg)
|
data, err := shared.FetchContext(r, s.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,7 +178,7 @@ func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||||
s.peer.ServeHTTP(w, r)
|
s.peer.ServeHTTP(w, r)
|
||||||
default:
|
default:
|
||||||
router.SendError(w, r, router.ErrNoRouterFound)
|
shared.SendError(w, r, router.ErrNoRouterFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,21 +193,13 @@ func stripVersionPrefix(r *http.Request) {
|
|||||||
// routes builds the mux, registers every route, and wraps the mux with the
|
// routes builds the mux, registers every route, and wraps the mux with the
|
||||||
// global CORS middleware.
|
// global CORS middleware.
|
||||||
func (s *Server) routes() {
|
func (s *Server) routes() {
|
||||||
authMW := CreateAuthMiddleware(s.cfg)
|
|
||||||
filterMW := CreateFilterMiddleware(s.cfg)
|
|
||||||
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
|
||||||
|
|
||||||
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
authMW := CreateAuthMiddleware(s.cfg)
|
||||||
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
|
||||||
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
|
||||||
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
|
||||||
// other's Content-Type. Both run before the metrics middleware so it buffers
|
|
||||||
// the rewritten body.
|
|
||||||
modelChain := chain.New(
|
modelChain := chain.New(
|
||||||
authMW,
|
authMW,
|
||||||
CreateConcurrencyMiddleware(s.cfg),
|
CreateRequestContextMiddleware(s.cfg),
|
||||||
filterMW,
|
CreateFilterMiddleware(s.cfg),
|
||||||
formFilterMW,
|
CreateFormFilterMiddleware(s.cfg),
|
||||||
CreateInflightMiddleware(s.inflight),
|
CreateInflightMiddleware(s.inflight),
|
||||||
CreateMetricsMiddleware(s.metrics, s.cfg),
|
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||||
)
|
)
|
||||||
@@ -215,19 +230,21 @@ func (s *Server) routes() {
|
|||||||
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||||
|
|
||||||
// Embedded UI.
|
// Embedded UI.
|
||||||
mux.HandleFunc("GET /ui/", s.handleUI)
|
mux.Handle("GET /ui/", chain.New(authMW).ThenFunc(s.handleUI))
|
||||||
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||||
|
|
||||||
// Prometheus metrics (no auth, matches the legacy endpoint).
|
// Prometheus metrics (wrapped by apiChain, matches the legacy endpoint).
|
||||||
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
mux.Handle("GET /metrics", apiChain.ThenFunc(s.handleMetrics))
|
||||||
|
|
||||||
// Operations endpoints.
|
// Operations endpoints.
|
||||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||||
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||||
|
|
||||||
// Upstream passthrough.
|
// Upstream passthrough. Meter only the model-dispatched endpoints that can
|
||||||
|
// produce token usage/timings.
|
||||||
|
upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg))
|
||||||
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||||
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream))
|
||||||
|
|
||||||
// API group (API-key protected) consumed by the UI.
|
// API group (API-key protected) consumed by the UI.
|
||||||
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||||
|
|||||||
@@ -84,10 +84,15 @@ func chatRequest(model string) *http.Request {
|
|||||||
|
|
||||||
func TestServer_New_GroupConfig(t *testing.T) {
|
func TestServer_New_GroupConfig(t *testing.T) {
|
||||||
discard := logmon.NewWriter(io.Discard)
|
discard := logmon.NewWriter(io.Discard)
|
||||||
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
|
cfg := config.Config{HealthCheckTimeout: 15}
|
||||||
|
cfg.Routing.Router.Use = "group"
|
||||||
|
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("New (group): %v", err)
|
t.Fatalf("New (group): %v", err)
|
||||||
}
|
}
|
||||||
|
if _, ok := s.local.(*router.Group); !ok {
|
||||||
|
t.Fatalf("localRouter=%T want *router.Group", s.local)
|
||||||
|
}
|
||||||
if err := s.Shutdown(time.Second); err != nil {
|
if err := s.Shutdown(time.Second); err != nil {
|
||||||
t.Fatalf("Shutdown: %v", err)
|
t.Fatalf("Shutdown: %v", err)
|
||||||
}
|
}
|
||||||
@@ -95,11 +100,16 @@ func TestServer_New_GroupConfig(t *testing.T) {
|
|||||||
|
|
||||||
func TestServer_New_MatrixConfig(t *testing.T) {
|
func TestServer_New_MatrixConfig(t *testing.T) {
|
||||||
discard := logmon.NewWriter(io.Discard)
|
discard := logmon.NewWriter(io.Discard)
|
||||||
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
|
cfg := config.Config{HealthCheckTimeout: 15}
|
||||||
|
cfg.Routing.Router.Use = "matrix"
|
||||||
|
cfg.Routing.Router.Settings.Matrix = &config.MatrixConfig{}
|
||||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("New (matrix): %v", err)
|
t.Fatalf("New (matrix): %v", err)
|
||||||
}
|
}
|
||||||
|
if _, ok := s.local.(*router.Matrix); !ok {
|
||||||
|
t.Fatalf("localRouter=%T want *router.Matrix", s.local)
|
||||||
|
}
|
||||||
if err := s.Shutdown(time.Second); err != nil {
|
if err := s.Shutdown(time.Second); err != nil {
|
||||||
t.Fatalf("Shutdown: %v", err)
|
t.Fatalf("Shutdown: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,298 @@
|
|||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// contextkey is the unexported type used for this package's context keys.
// name only labels the key.
type contextkey struct{ name string }
|
||||||
|
|
||||||
|
// ReqContextData is the per-request data stored in the request context under
// ReqContextKey (see SetContext and ReadContext).
type ReqContextData struct {
	// ApiKey is the candidate API key taken from the request headers.
	ApiKey string

	// Model is the model name as the client supplied it (request body, query
	// string, or /upstream/<model> path).
	Model string

	// ModelID is the real model name that Model resolves to in the config.
	ModelID string

	// Streaming reports whether the request asked for a streamed response.
	Streaming bool

	// SendLoadingState mirrors the resolved model's SendLoadingState setting.
	SendLoadingState bool

	// Metadata is a request-scoped key/value bag that handlers may mutate
	// while processing. The metrics middleware copies it into ActivityLogEntry.
	Metadata map[string]string
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ReqContextKey = &contextkey{"context"}
|
||||||
|
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||||
|
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||||
|
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||||
|
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
var httpErr HTTPError
|
||||||
|
if errors.As(err, &httpErr) {
|
||||||
|
for k, v := range httpErr.Header() {
|
||||||
|
w.Header()[k] = v
|
||||||
|
}
|
||||||
|
w.WriteHeader(httpErr.StatusCode())
|
||||||
|
w.Write(httpErr.Body())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrNoModelInContext):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||||
|
case errors.Is(err, ErrNoPeerModelFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||||
|
case errors.Is(err, ErrNoLocalModelFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||||
|
case errors.Is(err, ErrNoRouterFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||||
|
default:
|
||||||
|
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendResponse writes message with the given status in the representation the
// client requested through its Accept header. text/plain is checked first,
// then text/html (message is HTML-escaped). Any other Accept value gets a
// JSON object with "src" and "error" fields.
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
	accept := r.Header.Get("Accept")

	switch {
	case strings.Contains(accept, "text/plain"):
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(status)
		fmt.Fprintf(w, "llama-swap: %s", message)

	case strings.Contains(accept, "text/html"):
		w.Header().Set("Content-Type", "text/html")
		w.WriteHeader(status)
		fmt.Fprintf(w, `<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))

	default:
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		payload, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message})
		if err != nil {
			// The status line is already sent, so fall back to a fixed body.
			w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`))
			return
		}
		w.Write(payload)
	}
}
|
||||||
|
|
||||||
|
// FetchContext will attempt to get the model id from the context, then
|
||||||
|
// from an /upstream/<model> path prefix, then from the request body/query.
|
||||||
|
// If it extracts the model it will store it in the context for downstream
|
||||||
|
// handlers. An error will be returned when a model cannot be identified.
|
||||||
|
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||||
|
data, ok := ReadContext(r.Context())
|
||||||
|
if ok {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||||
|
if data, ok := extractUpstreamContext(r, cfg); ok {
|
||||||
|
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||||
|
realName, _ := cfg.RealModelName(data.Model)
|
||||||
|
if realName == "" {
|
||||||
|
realName = data.Model
|
||||||
|
}
|
||||||
|
data.ModelID = realName
|
||||||
|
if mc, ok := cfg.Models[realName]; ok {
|
||||||
|
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||||
|
}
|
||||||
|
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
|
||||||
|
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
|
||||||
|
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||||
|
if !found {
|
||||||
|
return ReqContextData{}, false
|
||||||
|
}
|
||||||
|
return ReqContextData{
|
||||||
|
Model: searchName,
|
||||||
|
ModelID: realName,
|
||||||
|
ApiKey: ExtractAPIKey(r),
|
||||||
|
Streaming: r.URL.Query().Get("stream") == "true",
|
||||||
|
SendLoadingState: sendLoadingState(cfg, realName),
|
||||||
|
Metadata: make(map[string]string),
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendLoadingState reports whether the configured model wants loading-state SSEs.
|
||||||
|
func sendLoadingState(cfg config.Config, modelID string) bool {
|
||||||
|
if mc, ok := cfg.Models[modelID]; ok {
|
||||||
|
return mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindModelInPath walks a slash-separated path, building up segments until one
|
||||||
|
// matches a configured model. This resolves model names that contain slashes
|
||||||
|
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||||
|
// remaining path, and whether a match was found.
|
||||||
|
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||||
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||||
|
name := ""
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = part
|
||||||
|
} else {
|
||||||
|
name = name + "/" + part
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelID, ok := cfg.RealModelName(name); ok {
|
||||||
|
searchName = name
|
||||||
|
realName = modelID
|
||||||
|
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||||
|
return context.WithValue(ctx, ReqContextKey, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||||
|
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||||
|
return data, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReqData attaches a key/value pair to the request context's metadata map.
|
||||||
|
// The metadata map must already exist in the context's ReqContextData; callers
|
||||||
|
// should ensure FetchContext has run or initialize the map themselves.
|
||||||
|
// It returns an error for nil contexts or contexts without request data.
|
||||||
|
func SetReqData(ctx context.Context, key, value string) error {
|
||||||
|
if ctx == nil {
|
||||||
|
return fmt.Errorf("cannot set request metadata on nil context")
|
||||||
|
}
|
||||||
|
data, ok := ReadContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("no request context data found")
|
||||||
|
}
|
||||||
|
if data.Metadata == nil {
|
||||||
|
return fmt.Errorf("no metadata map in request context")
|
||||||
|
}
|
||||||
|
data.Metadata[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractContext pulls fields from an HTTP request into a ReqContextData,
|
||||||
|
// returning whatever is available. For GET requests it reads query parameters.
|
||||||
|
// For POST requests it inspects Content-Type and parses JSON,
|
||||||
|
// multipart/form-data, or application/x-www-form-urlencoded bodies. The
|
||||||
|
// request body is always restored before returning. An error is returned only
|
||||||
|
// for I/O or parse failures, not for missing fields.
|
||||||
|
func extractContext(r *http.Request) (ReqContextData, error) {
|
||||||
|
|
||||||
|
apiKey := ExtractAPIKey(r)
|
||||||
|
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
q := r.URL.Query()
|
||||||
|
return ReqContextData{
|
||||||
|
Model: q.Get("model"),
|
||||||
|
Streaming: q.Get("stream") == "true",
|
||||||
|
ApiKey: apiKey,
|
||||||
|
Metadata: make(map[string]string),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
}()
|
||||||
|
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
return ReqContextData{
|
||||||
|
Model: gjson.GetBytes(bodyBytes, "model").String(),
|
||||||
|
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||||
|
ApiKey: apiKey,
|
||||||
|
Metadata: make(map[string]string),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||||
|
// buffered bytes. The deferred restore above will reset r.Body again
|
||||||
|
// after parsing.
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
if strings.Contains(contentType, "multipart/form-data") {
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReqContextData{
|
||||||
|
Model: r.FormValue("model"),
|
||||||
|
Streaming: r.FormValue("stream") == "true",
|
||||||
|
ApiKey: apiKey,
|
||||||
|
Metadata: make(map[string]string),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||||
|
// then Bearer, then x-api-key.
|
||||||
|
func ExtractAPIKey(r *http.Request) string {
|
||||||
|
var bearerKey, basicKey string
|
||||||
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
|
scheme, credentials, ok := strings.Cut(auth, " ")
|
||||||
|
if ok {
|
||||||
|
switch strings.ToLower(scheme) {
|
||||||
|
case "bearer":
|
||||||
|
bearerKey = credentials
|
||||||
|
case "basic":
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
|
||||||
|
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||||
|
basicKey = parts[1] // password field is the API key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case basicKey != "":
|
||||||
|
return basicKey
|
||||||
|
case bearerKey != "":
|
||||||
|
return bearerKey
|
||||||
|
default:
|
||||||
|
return r.Header.Get("x-api-key")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,525 @@
|
|||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractContext_GET(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", "model=llama3", "llama3", false},
|
||||||
|
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||||
|
{"model missing", "", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_JSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||||
|
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||||
|
{"model empty string", `{"model":""}`, "", false},
|
||||||
|
{"model key missing", `{"stream":true}`, "", false},
|
||||||
|
{"invalid json", `not-json`, "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
formModel string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", "whisper-1", "whisper-1", false},
|
||||||
|
{"model missing", "", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
form := url.Values{}
|
||||||
|
if tt.formModel != "" {
|
||||||
|
form.Set("model", tt.formModel)
|
||||||
|
}
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_MultipartForm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
formModel string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", "whisper-1", "whisper-1", false},
|
||||||
|
{"model missing", "", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
mw := multipart.NewWriter(&buf)
|
||||||
|
if tt.formModel != "" {
|
||||||
|
fw, _ := mw.CreateFormField("model")
|
||||||
|
fw.Write([]byte(tt.formModel))
|
||||||
|
}
|
||||||
|
mw.Close()
|
||||||
|
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||||
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||||
|
body := `{"model":"llama3","stream":true}`
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if _, err := extractContext(r); err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if string(remaining) != body {
|
||||||
|
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
mw := multipart.NewWriter(&buf)
|
||||||
|
fw, _ := mw.CreateFormField("model")
|
||||||
|
fw.Write([]byte("whisper-1"))
|
||||||
|
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
||||||
|
ff.Write([]byte("fake-audio-bytes"))
|
||||||
|
mw.Close()
|
||||||
|
|
||||||
|
original := buf.Bytes()
|
||||||
|
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||||
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
|
||||||
|
if _, err := extractContext(r); err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(remaining, original) {
|
||||||
|
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||||
|
body := "model=whisper-1&extra=value"
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
if _, err := extractContext(r); err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if string(remaining) != body {
|
||||||
|
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContext(t *testing.T) {
|
||||||
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
|
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("ContextKey not set or wrong type")
|
||||||
|
}
|
||||||
|
if data.Model != "llama3" {
|
||||||
|
t.Errorf("want %q got %q", "llama3", data.Model)
|
||||||
|
}
|
||||||
|
if data.ModelID != "llama3" {
|
||||||
|
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContext_WithAlias(t *testing.T) {
|
||||||
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||||
|
data, _ := ctx.Value(ReqContextKey).(ReqContextData)
|
||||||
|
if data.Model != "llama" {
|
||||||
|
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||||
|
}
|
||||||
|
if data.ModelID != "llama3" {
|
||||||
|
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||||
|
parent := context.Background()
|
||||||
|
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
|
if v := parent.Value(ReqContextKey); v != nil {
|
||||||
|
t.Errorf("parent context was mutated: %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
wantReq string
|
||||||
|
wantReal string
|
||||||
|
wantBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model present, same name",
|
||||||
|
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
||||||
|
wantReq: "llama3",
|
||||||
|
wantReal: "llama3",
|
||||||
|
wantBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model present, aliased",
|
||||||
|
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
||||||
|
wantReq: "llama",
|
||||||
|
wantReal: "llama3",
|
||||||
|
wantBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model absent",
|
||||||
|
ctx: context.Background(),
|
||||||
|
wantReq: "",
|
||||||
|
wantReal: "",
|
||||||
|
wantBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model is empty string",
|
||||||
|
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
||||||
|
wantReq: "",
|
||||||
|
wantReal: "",
|
||||||
|
wantBool: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotData, ok := ReadContext(tt.ctx)
|
||||||
|
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
||||||
|
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
wantStreaming bool
|
||||||
|
}{
|
||||||
|
{"streaming true", "model=llama3&stream=true", true},
|
||||||
|
{"streaming false", "model=llama3&stream=false", false},
|
||||||
|
{"no stream param", "model=llama3", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.Streaming != tt.wantStreaming {
|
||||||
|
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantStreaming bool
|
||||||
|
}{
|
||||||
|
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||||
|
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||||
|
{"no stream param", `{"model":"llama3"}`, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.Streaming != tt.wantStreaming {
|
||||||
|
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if !got.Streaming {
|
||||||
|
t.Error("Streaming should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_ApiKey(t *testing.T) {
|
||||||
|
basicHeader := func(user, pass string) string {
|
||||||
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||||
|
}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
ct string
|
||||||
|
body string
|
||||||
|
auth string
|
||||||
|
xapi string
|
||||||
|
wantKey string
|
||||||
|
}{
|
||||||
|
{"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"},
|
||||||
|
{"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"},
|
||||||
|
{"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"},
|
||||||
|
{"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"},
|
||||||
|
{"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"},
|
||||||
|
{"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"},
|
||||||
|
{"no key", http.MethodGet, "", "", "", "", ""},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
var body io.Reader
|
||||||
|
if c.body != "" {
|
||||||
|
body = strings.NewReader(c.body)
|
||||||
|
}
|
||||||
|
r, _ := http.NewRequest(c.method, "/", body)
|
||||||
|
if c.ct != "" {
|
||||||
|
r.Header.Set("Content-Type", c.ct)
|
||||||
|
}
|
||||||
|
if c.auth != "" {
|
||||||
|
r.Header.Set("Authorization", c.auth)
|
||||||
|
}
|
||||||
|
if c.xapi != "" {
|
||||||
|
r.Header.Set("x-api-key", c.xapi)
|
||||||
|
}
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("extractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.ApiKey != c.wantKey {
|
||||||
|
t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetReqData(t *testing.T) {
|
||||||
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3", Metadata: make(map[string]string)})
|
||||||
|
|
||||||
|
if err := SetReqData(ctx, "client", "web"); err != nil {
|
||||||
|
t.Fatalf("SetReqData: %v", err)
|
||||||
|
}
|
||||||
|
if err := SetReqData(ctx, "trace", "abc123"); err != nil {
|
||||||
|
t.Fatalf("SetReqData: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := ReadContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("context data missing")
|
||||||
|
}
|
||||||
|
if data.Metadata["client"] != "web" {
|
||||||
|
t.Errorf("client = %q, want %q", data.Metadata["client"], "web")
|
||||||
|
}
|
||||||
|
if data.Metadata["trace"] != "abc123" {
|
||||||
|
t.Errorf("trace = %q, want %q", data.Metadata["trace"], "abc123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetReqData_Errors(t *testing.T) {
|
||||||
|
if err := SetReqData(context.Background(), "k", "v"); err == nil {
|
||||||
|
t.Error("expected error when no request context data exists")
|
||||||
|
}
|
||||||
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
|
if err := SetReqData(ctx, "k", "v"); err == nil {
|
||||||
|
t.Error("expected error when metadata map is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||||
|
basicHeader := func(user, pass string) string {
|
||||||
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||||
|
}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth string
|
||||||
|
xapi string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"none", "", "", ""},
|
||||||
|
{"bearer", "Bearer tok123", "", "tok123"},
|
||||||
|
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||||
|
{"x-api-key", "", "xkey", "xkey"},
|
||||||
|
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||||
|
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||||
|
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||||
|
{"lowercase bearer", "bearer tok123", "", "tok123"},
|
||||||
|
{"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"},
|
||||||
|
{"mixed case BEARER", "BEARER tok456", "", "tok456"},
|
||||||
|
{"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if c.auth != "" {
|
||||||
|
r.Header.Set("Authorization", c.auth)
|
||||||
|
}
|
||||||
|
if c.xapi != "" {
|
||||||
|
r.Header.Set("x-api-key", c.xapi)
|
||||||
|
}
|
||||||
|
if got := ExtractAPIKey(r); got != c.want {
|
||||||
|
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchContext_UpstreamPath(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"m1": {},
|
||||||
|
"author/model": {},
|
||||||
|
"real": {Aliases: []string{"nick"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
wantModel string
|
||||||
|
wantModelID string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
|
||||||
|
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
|
||||||
|
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
|
||||||
|
{"bare model path", "/upstream/m1/", "m1", "m1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
|
||||||
|
data, err := FetchContext(r, cfg)
|
||||||
|
if (err != nil) != c.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
|
||||||
|
}
|
||||||
|
if c.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data.Model != c.wantModel {
|
||||||
|
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
|
||||||
|
}
|
||||||
|
if data.ModelID != c.wantModelID {
|
||||||
|
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
|
||||||
|
}
|
||||||
|
if data.Metadata == nil {
|
||||||
|
t.Error("metadata map not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
body := `{"model":"should-not-matter"}`
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
|
||||||
|
|
||||||
|
_, err := FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FetchContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The body should be untouched so the upstream handler can still read it.
|
||||||
|
got, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read body: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != body {
|
||||||
|
t.Errorf("body was consumed: %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPError is an error that describes a complete HTTP response: status code,
// headers and body. It separates the component that decides to reject a
// request (for example a scheduler shedding load) from the component that
// writes the response (for example router.SendError). The writer can emit the
// three parts verbatim rather than mapping the error to a generic status, so a
// rejection such as a 429 with rate-limit headers and a JSON hint body reaches
// the client without the writer knowing anything about the producer.
type HTTPError interface {
	error

	// StatusCode is the HTTP status to send.
	StatusCode() int

	// Header is the set of response headers to send.
	Header() http.Header

	// Body is the raw response body to send.
	Body() []byte
}
|
||||||
|
|
||||||
|
// ConcurrencyLimitError is the HTTPError for a 429 concurrency-limit
// rejection. Its zero value is usable: unset fields are replaced by defaults
// (a 1-second Retry-After and a JSON hint body) when the response parts are
// built.
type ConcurrencyLimitError struct {
	// RetryAfter is the number of seconds sent in the Retry-After header.
	// Values <= 0 are sent as 1.
	RetryAfter int

	// Message is the value of the "error" field in the JSON body. An empty
	// string is replaced by "Too many requests".
	Message string
}

// Error implements the error interface with a fixed description.
func (e ConcurrencyLimitError) Error() string {
	return "concurrency limit reached"
}

// StatusCode always reports 429 Too Many Requests.
func (e ConcurrencyLimitError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Header returns a newly built header set containing a JSON Content-Type and
// the Retry-After value.
func (e ConcurrencyLimitError) Header() http.Header {
	return http.Header{
		"Content-Type": []string{"application/json"},
		"Retry-After":  []string{e.retryAfter()},
	}
}

// Body returns the JSON document {"error": <message>}.
func (e ConcurrencyLimitError) Body() []byte {
	payload := struct {
		Error string `json:"error"`
	}{Error: e.message()}

	encoded, _ := json.Marshal(payload)
	return encoded
}

// retryAfter renders the Retry-After header value, defaulting to "1".
func (e ConcurrencyLimitError) retryAfter() string {
	if e.RetryAfter <= 0 {
		return "1"
	}
	return strconv.Itoa(e.RetryAfter)
}

// message returns the body message, defaulting to "Too many requests".
func (e ConcurrencyLimitError) message() string {
	if e.Message == "" {
		return "Too many requests"
	}
	return e.Message
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package shared
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// IsLoopbackAddr reports whether listenAddr binds exclusively to loopback.
//
// listenAddr must be in host:port form; anything net.SplitHostPort rejects
// yields false. An empty host (":8080") or a wildcard host ("0.0.0.0:8080",
// "[::]:8080") binds on all interfaces and yields false. A host that is not an
// IP literal (e.g. "localhost") is resolved with net.LookupHost and counts as
// loopback only when the lookup succeeds, returns at least one address, and
// every returned address is loopback.
func IsLoopbackAddr(listenAddr string) bool {
	host, _, splitErr := net.SplitHostPort(listenAddr)
	if splitErr != nil || host == "" {
		return false
	}

	// IP literal: answer directly, no name resolution needed.
	if literal := net.ParseIP(host); literal != nil {
		return literal.IsLoopback()
	}

	// Hostname: every resolved address has to be loopback.
	resolved, lookupErr := net.LookupHost(host)
	if lookupErr != nil || len(resolved) == 0 {
		return false
	}
	for i := range resolved {
		if !net.ParseIP(resolved[i]).IsLoopback() {
			return false
		}
	}
	return true
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -262,6 +263,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if !shared.IsLoopbackAddr(listenAddr) {
|
||||||
|
_, port, _ := net.SplitHostPort(listenAddr)
|
||||||
|
proxyLog.Infof("llama-swap is reachable by all hosts on the network, use -listen localhost:%s to restrict to loopback only", port)
|
||||||
|
}
|
||||||
|
|
||||||
exitChan := make(chan struct{})
|
exitChan := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
import Performance from "./routes/Performance.svelte";
|
import Performance from "./routes/Performance.svelte";
|
||||||
import Playground from "./routes/Playground.svelte";
|
import Playground from "./routes/Playground.svelte";
|
||||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||||
import { enableAPIEvents } from "./stores/api";
|
import { enableAPIEvents, checkPerformanceEnabled } from "./stores/api";
|
||||||
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||||
import { currentRoute } from "./stores/route";
|
import { currentRoute } from "./stores/route";
|
||||||
|
|
||||||
@@ -39,6 +39,7 @@
|
|||||||
const cleanupScreenWidth = initScreenWidth();
|
const cleanupScreenWidth = initScreenWidth();
|
||||||
const cleanupSystemTheme = initSystemThemeListener();
|
const cleanupSystemTheme = initSystemThemeListener();
|
||||||
enableAPIEvents(true);
|
enableAPIEvents(true);
|
||||||
|
checkPerformanceEnabled();
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
cleanupScreenWidth();
|
cleanupScreenWidth();
|
||||||
|
|||||||
@@ -193,7 +193,7 @@
|
|||||||
<dialog
|
<dialog
|
||||||
bind:this={dialogEl}
|
bind:this={dialogEl}
|
||||||
onclose={handleDialogClose}
|
onclose={handleDialogClose}
|
||||||
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
|
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-[80%] w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
|
||||||
>
|
>
|
||||||
{#if capture}
|
{#if capture}
|
||||||
<div class="flex flex-col max-h-[90vh]">
|
<div class="flex flex-col max-h-[90vh]">
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
||||||
import { currentRoute } from "../stores/route";
|
import { currentRoute } from "../stores/route";
|
||||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||||
|
import { performanceEnabled } from "../stores/api";
|
||||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||||
|
|
||||||
function handleTitleChange(newTitle: string): void {
|
function handleTitleChange(newTitle: string): void {
|
||||||
@@ -84,16 +85,18 @@
|
|||||||
>
|
>
|
||||||
Logs
|
Logs
|
||||||
</a>
|
</a>
|
||||||
<a
|
{#if $performanceEnabled}
|
||||||
href="/performance"
|
<a
|
||||||
use:link
|
href="/performance"
|
||||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
use:link
|
||||||
class:font-semibold={isActive("/performance", $currentRoute)}
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
class:underline={isActive("/performance", $currentRoute)}
|
class:font-semibold={isActive("/performance", $currentRoute)}
|
||||||
class:underline-offset-4={isActive("/performance", $currentRoute)}
|
class:underline={isActive("/performance", $currentRoute)}
|
||||||
>
|
class:underline-offset-4={isActive("/performance", $currentRoute)}
|
||||||
Performance
|
>
|
||||||
</a>
|
Performance
|
||||||
|
</a>
|
||||||
|
{/if}
|
||||||
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
||||||
{#if $themeMode === "system"}
|
{#if $themeMode === "system"}
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import type { Snippet } from "svelte";
|
||||||
|
|
||||||
|
/** Props accepted by this component. */
interface Props {
  /** Key/value pairs to list in the tooltip; may be undefined. */
  metadata: undefined | Record<string, string>;
  /** Content rendered inside the hover/focus trigger. */
  children: Snippet;
}
|
||||||
|
|
||||||
|
let { metadata, children }: Props = $props();
|
||||||
|
|
||||||
|
let entries = $derived(Object.entries(metadata || {}));
|
||||||
|
let triggerEl: HTMLElement | undefined = $state();
|
||||||
|
let tooltipEl: HTMLDivElement | undefined = $state();
|
||||||
|
let show = $state(false);
|
||||||
|
let tooltipStyle = $state("");
|
||||||
|
|
||||||
|
/**
 * Computes where the tooltip should sit relative to the trigger and stores the
 * result in `tooltipStyle` as inline `left` / `top` / `max-width` CSS.
 *
 * Both rectangles come from `getBoundingClientRect()`, so all values are in
 * viewport coordinates. The tooltip is placed below the trigger, left-aligned
 * with it, then adjusted so it stays inside the viewport with an 8px margin:
 * right-aligned to the trigger if it would overflow the right edge, flipped
 * above the trigger if it would overflow the bottom, and finally clamped to
 * the left and top margins.
 *
 * Does nothing until both the trigger and the tooltip element exist.
 */
function positionTooltip() {
  if (!triggerEl || !tooltipEl) return;

  const triggerRect = triggerEl.getBoundingClientRect();
  const tooltipRect = tooltipEl.getBoundingClientRect();
  const margin = 8;
  const viewportWidth = window.innerWidth;
  const viewportHeight = window.innerHeight;

  let left = triggerRect.left;
  let top = triggerRect.bottom + margin;

  // Keep tooltip within horizontal viewport bounds
  if (left + tooltipRect.width > viewportWidth - margin) {
    left = triggerRect.right - tooltipRect.width;
  }
  if (left < margin) {
    left = margin;
  }

  // Flip above trigger if it would overflow the bottom
  if (top + tooltipRect.height > viewportHeight - margin) {
    top = triggerRect.top - tooltipRect.height - margin;
  }
  // Flipping can push the tooltip past the top edge when there is not enough
  // room above the trigger either; clamp so its top stays on screen.
  if (top < margin) {
    top = margin;
  }

  tooltipStyle = `left: ${left}px; top: ${top}px; max-width: calc(100vw - ${margin * 2}px);`;
}
|
||||||
|
|
||||||
|
/**
 * Shows the tooltip and schedules positioning for the next animation frame.
 */
function onEnter() {
  show = true;
  requestAnimationFrame(() => positionTooltip());
}
|
||||||
|
|
||||||
|
/** Hides the tooltip. */
function onLeave() {
  show = false;
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<span
|
||||||
|
bind:this={triggerEl}
|
||||||
|
onmouseenter={onEnter}
|
||||||
|
onmouseleave={onLeave}
|
||||||
|
onfocus={onEnter}
|
||||||
|
onblur={onLeave}
|
||||||
|
class="inline-flex"
|
||||||
|
role="button"
|
||||||
|
tabindex="0"
|
||||||
|
aria-label="Show metadata"
|
||||||
|
>
|
||||||
|
{@render children()}
|
||||||
|
</span>
|
||||||
|
|
||||||
|
{#if show && entries.length > 0}
|
||||||
|
<div
|
||||||
|
bind:this={tooltipEl}
|
||||||
|
style={tooltipStyle}
|
||||||
|
class="fixed px-3 py-2 bg-gray-900 text-white text-sm rounded-md z-50 normal-case min-w-[12rem] max-w-[24rem] shadow-lg whitespace-normal"
|
||||||
|
>
|
||||||
|
<table class="w-full text-left">
|
||||||
|
<tbody>
|
||||||
|
{#each entries as [key, value]}
|
||||||
|
<tr class="border-b border-white/10 last:border-0">
|
||||||
|
<td class="py-1 pr-3 font-medium whitespace-nowrap text-primary">{key}</td>
|
||||||
|
<td class="py-1 break-all">{value}</td>
|
||||||
|
</tr>
|
||||||
|
{/each}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
@@ -6,6 +6,8 @@
|
|||||||
|
|
||||||
let isUnloading = $state(false);
|
let isUnloading = $state(false);
|
||||||
let menuOpen = $state(false);
|
let menuOpen = $state(false);
|
||||||
|
let pendingLoads = $state<Record<string, boolean>>({});
|
||||||
|
const loadControllers = new Map<string, AbortController>();
|
||||||
|
|
||||||
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
||||||
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
||||||
@@ -42,6 +44,25 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Starts loading a model unless a load for the same ID is already pending.
 *
 * While the load is in flight the model is flagged in `pendingLoads` and its
 * AbortController is kept in `loadControllers`. Both entries are removed when
 * the load settles, whether it succeeded or failed. Errors from `loadModel`
 * are logged to the console and not rethrown.
 */
async function handleLoadModel(modelId: string): Promise<void> {
  const alreadyPending = pendingLoads[modelId];
  if (alreadyPending) {
    return;
  }

  const abortController = new AbortController();
  loadControllers.set(modelId, abortController);
  pendingLoads[modelId] = true;

  try {
    await loadModel(modelId, abortController.signal);
  } catch (err) {
    console.error(err);
  } finally {
    loadControllers.delete(modelId);
    delete pendingLoads[modelId];
  }
}
|
||||||
|
|
||||||
|
/**
 * Aborts the in-flight load registered for `modelId` in `loadControllers`.
 * Does nothing when no controller is registered for that ID.
 */
function cancelLoad(modelId: string): void {
  const controller = loadControllers.get(modelId);
  if (controller) {
    controller.abort();
  }
}
|
||||||
|
|
||||||
function toggleIdorName(): void {
|
function toggleIdorName(): void {
|
||||||
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
||||||
}
|
}
|
||||||
@@ -170,14 +191,20 @@
|
|||||||
{/if}
|
{/if}
|
||||||
</td>
|
</td>
|
||||||
<td class="w-12">
|
<td class="w-12">
|
||||||
{#if model.state === "stopped"}
|
{#if model.state === "stopped" && pendingLoads[model.id]}
|
||||||
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
|
<button class="btn btn--sm" onclick={() => cancelLoad(model.id)}>Cancel</button>
|
||||||
|
{:else if model.state === "stopped"}
|
||||||
|
<button class="btn btn--sm" onclick={() => handleLoadModel(model.id)}>Load</button>
|
||||||
{:else}
|
{:else}
|
||||||
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
||||||
{/if}
|
{/if}
|
||||||
</td>
|
</td>
|
||||||
<td class="w-20">
|
<td class="w-20">
|
||||||
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
{#if model.state === "stopped" && pendingLoads[model.id]}
|
||||||
|
<span class="w-16 text-center status status--queued">queued</span>
|
||||||
|
{:else}
|
||||||
|
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||||
|
{/if}
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{/each}
|
{/each}
|
||||||
|
|||||||
@@ -145,7 +145,7 @@
|
|||||||
<div class="flex flex-col h-full">
|
<div class="flex flex-col h-full">
|
||||||
<!-- Model selector -->
|
<!-- Model selector -->
|
||||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} />
|
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} capabilities={["audio_transcriptions"]} />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Empty state for no models configured -->
|
<!-- Empty state for no models configured -->
|
||||||
|
|||||||
@@ -193,7 +193,7 @@
|
|||||||
<div class="flex flex-col h-full">
|
<div class="flex flex-col h-full">
|
||||||
<!-- Model selector and mode toggle -->
|
<!-- Model selector and mode toggle -->
|
||||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} />
|
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} capabilities={["image_generation", "image_to_image"]} matchAny={true} />
|
||||||
|
|
||||||
<select
|
<select
|
||||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||||
|
|||||||
@@ -6,12 +6,15 @@
|
|||||||
value: string;
|
value: string;
|
||||||
placeholder?: string;
|
placeholder?: string;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
|
capabilities?: string[];
|
||||||
|
matchAny?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
let { value = $bindable(), placeholder = "Select a model...", disabled = false }: Props = $props();
|
let { value = $bindable(), placeholder = "Select a model...", disabled = false, capabilities, matchAny = false }: Props = $props();
|
||||||
|
|
||||||
let grouped = $derived(groupModels($models));
|
let grouped = $derived(groupModels($models, capabilities, matchAny));
|
||||||
let hasModels = $derived(grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
|
let hasMatching = $derived(grouped.localMatching.length > 0);
|
||||||
|
let hasModels = $derived(hasMatching || grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
{#if hasModels}
|
{#if hasModels}
|
||||||
@@ -21,6 +24,18 @@
|
|||||||
{disabled}
|
{disabled}
|
||||||
>
|
>
|
||||||
<option value="">{placeholder}</option>
|
<option value="">{placeholder}</option>
|
||||||
|
{#if hasMatching}
|
||||||
|
<optgroup label="Matching Capabilities">
|
||||||
|
{#each grouped.localMatching as model (model.id)}
|
||||||
|
<option value={model.id}>{model.id}</option>
|
||||||
|
{#if model.aliases}
|
||||||
|
{#each model.aliases as alias (alias)}
|
||||||
|
<option value={alias}> ↳ {alias}</option>
|
||||||
|
{/each}
|
||||||
|
{/if}
|
||||||
|
{/each}
|
||||||
|
</optgroup>
|
||||||
|
{/if}
|
||||||
{#if grouped.local.length > 0}
|
{#if grouped.local.length > 0}
|
||||||
<optgroup label="Local">
|
<optgroup label="Local">
|
||||||
{#each grouped.local as model (model.id)}
|
{#each grouped.local as model (model.id)}
|
||||||
|
|||||||
@@ -264,7 +264,7 @@
|
|||||||
<div class="flex flex-col h-full">
|
<div class="flex flex-col h-full">
|
||||||
<!-- Top bar: model selector + query input (table mode) + mode toggle -->
|
<!-- Top bar: model selector + query input (table mode) + mode toggle -->
|
||||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} />
|
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} capabilities={["reranker"]} />
|
||||||
{#if editorMode === "table"}
|
{#if editorMode === "table"}
|
||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
|
|||||||
@@ -206,7 +206,7 @@
|
|||||||
<div class="flex flex-col h-full">
|
<div class="flex flex-col h-full">
|
||||||
<!-- Model and voice selectors -->
|
<!-- Model and voice selectors -->
|
||||||
<div class="shrink-0 flex gap-2 mb-4">
|
<div class="shrink-0 flex gap-2 mb-4">
|
||||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} />
|
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} capabilities={["audio_speech"]} />
|
||||||
<div class="flex gap-2">
|
<div class="flex gap-2">
|
||||||
<select
|
<select
|
||||||
class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||||
|
|||||||
@@ -139,7 +139,8 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.status--starting,
|
.status--starting,
|
||||||
.status--stopping {
|
.status--stopping,
|
||||||
|
.status--queued {
|
||||||
@apply bg-warning/10 text-warning;
|
@apply bg-warning/10 text-warning;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
import { describe, it, expect } from "vitest";
|
||||||
|
import { matchesCapabilities, groupModels } from "./modelUtils";
|
||||||
|
import type { Model } from "./types";
|
||||||
|
|
||||||
|
function makeModel(overrides: Partial<Model> = {}): Model {
|
||||||
|
return {
|
||||||
|
id: "test-model",
|
||||||
|
state: "ready",
|
||||||
|
name: "Test Model",
|
||||||
|
description: "",
|
||||||
|
unlisted: false,
|
||||||
|
peerID: "",
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("matchesCapabilities", () => {
|
||||||
|
it("returns true when required is empty", () => {
|
||||||
|
const model = makeModel();
|
||||||
|
expect(matchesCapabilities(model, [])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns false when model has no capabilities", () => {
|
||||||
|
const model = makeModel();
|
||||||
|
expect(matchesCapabilities(model, ["vision"])).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns false when model has empty capabilities object", () => {
|
||||||
|
const model = makeModel({ capabilities: {} });
|
||||||
|
expect(matchesCapabilities(model, ["vision"])).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns true when model has the single required capability", () => {
|
||||||
|
const model = makeModel({ capabilities: { vision: true } });
|
||||||
|
expect(matchesCapabilities(model, ["vision"])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns false when model lacks the required capability", () => {
|
||||||
|
const model = makeModel({ capabilities: { vision: true } });
|
||||||
|
expect(matchesCapabilities(model, ["audio_transcriptions"])).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("AND semantics: returns true only when all required are present", () => {
|
||||||
|
const model = makeModel({ capabilities: { vision: true, audio_transcriptions: true } });
|
||||||
|
expect(matchesCapabilities(model, ["vision", "audio_transcriptions"])).toBe(true);
|
||||||
|
expect(matchesCapabilities(model, ["vision", "reranker"])).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("matchAny=true: returns true when at least one required is present", () => {
|
||||||
|
const model = makeModel({ capabilities: { vision: true } });
|
||||||
|
expect(matchesCapabilities(model, ["vision", "reranker"], true)).toBe(true);
|
||||||
|
expect(matchesCapabilities(model, ["audio_transcriptions", "reranker"], true)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("matchAny=true with empty required returns true", () => {
|
||||||
|
const model = makeModel();
|
||||||
|
expect(matchesCapabilities(model, [], true)).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("groupModels", () => {
|
||||||
|
const models: Model[] = [
|
||||||
|
makeModel({ id: "chat-model", capabilities: { vision: true } }),
|
||||||
|
makeModel({ id: "audio-model", capabilities: { audio_transcriptions: true } }),
|
||||||
|
makeModel({ id: "no-caps-model" }),
|
||||||
|
makeModel({ id: "peer-model", peerID: "peer1" }),
|
||||||
|
makeModel({ id: "unlisted-model", unlisted: true, capabilities: { vision: true } }),
|
||||||
|
];
|
||||||
|
|
||||||
|
it("filters out unlisted models", () => {
|
||||||
|
const result = groupModels(models);
|
||||||
|
expect(result.localMatching.length + result.local.length).toBe(3);
|
||||||
|
expect([...result.localMatching, ...result.local].every((m) => !m.unlisted)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("separates peer models into peersByProvider", () => {
|
||||||
|
const result = groupModels(models);
|
||||||
|
expect(result.peersByProvider["peer1"]).toHaveLength(1);
|
||||||
|
expect(result.peersByProvider["peer1"][0].id).toBe("peer-model");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("without capabilities, all local models go to local (non-matching)", () => {
|
||||||
|
const result = groupModels(models);
|
||||||
|
expect(result.localMatching).toHaveLength(0);
|
||||||
|
expect(result.local).toHaveLength(3);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("with capabilities, matching models go to localMatching", () => {
|
||||||
|
const result = groupModels(models, ["vision"]);
|
||||||
|
expect(result.localMatching).toHaveLength(1);
|
||||||
|
expect(result.localMatching[0].id).toBe("chat-model");
|
||||||
|
expect(result.local).toHaveLength(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("with capabilities, models without capabilities go to local", () => {
|
||||||
|
const result = groupModels(models, ["vision"]);
|
||||||
|
expect(result.local.find((m) => m.id === "no-caps-model")).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("with matchAny, matches models with any listed capability", () => {
|
||||||
|
const result = groupModels(models, ["vision", "audio_transcriptions"], true);
|
||||||
|
expect(result.localMatching).toHaveLength(2);
|
||||||
|
expect(result.localMatching.map((m) => m.id)).toContain("chat-model");
|
||||||
|
expect(result.localMatching.map((m) => m.id)).toContain("audio-model");
|
||||||
|
expect(result.local).toHaveLength(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("with empty capabilities array, all local go to local (non-matching)", () => {
|
||||||
|
const result = groupModels(models, []);
|
||||||
|
expect(result.localMatching).toHaveLength(0);
|
||||||
|
expect(result.local).toHaveLength(3);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -2,14 +2,40 @@ import type { Model } from "./types";
|
|||||||
|
|
||||||
export interface GroupedModels {
|
export interface GroupedModels {
|
||||||
local: Model[];
|
local: Model[];
|
||||||
|
localMatching: Model[];
|
||||||
peersByProvider: Record<string, Model[]>;
|
peersByProvider: Record<string, Model[]>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function groupModels(models: Model[]): GroupedModels {
|
/**
 * Decide whether a model advertises the capabilities a caller asks for.
 *
 * A capability counts as present only when its flag is exactly `true`;
 * missing or `false` flags are treated as absent.
 *
 * @param model - Model whose optional `capabilities` flags are inspected.
 * @param required - Capability names to look for. An empty list always matches.
 * @param matchAny - When `true`, a single present capability is enough (OR);
 *   otherwise every listed capability must be present (AND).
 * @returns `true` when the model satisfies the request, `false` otherwise
 *   (including when the model carries no `capabilities` object at all).
 */
export function matchesCapabilities(model: Model, required: readonly string[], matchAny = false): boolean {
  if (required.length === 0) return true;
  if (!model.capabilities) return false;

  // Collect the enabled flags once; this avoids asserting the capabilities
  // object to an index-signature type just to look names up dynamically.
  const enabled = new Set<string>();
  for (const [name, flag] of Object.entries(model.capabilities)) {
    if (flag === true) enabled.add(name);
  }

  const isEnabled = (cap: string): boolean => enabled.has(cap);
  return matchAny ? required.some(isEnabled) : required.every(isEnabled);
}
|
||||||
|
|
||||||
|
export function groupModels(models: Model[], capabilities?: string[], matchAny = false): GroupedModels {
|
||||||
const available = models.filter((m) => !m.unlisted);
|
const available = models.filter((m) => !m.unlisted);
|
||||||
const local = available.filter((m) => !m.peerID);
|
const local = available.filter((m) => !m.peerID);
|
||||||
const peerModels = available.filter((m) => m.peerID);
|
const peerModels = available.filter((m) => m.peerID);
|
||||||
|
|
||||||
|
let localMatching: Model[] = [];
|
||||||
|
let localRest: Model[] = [];
|
||||||
|
|
||||||
|
if (capabilities && capabilities.length > 0) {
|
||||||
|
for (const model of local) {
|
||||||
|
if (matchesCapabilities(model, capabilities, matchAny)) {
|
||||||
|
localMatching.push(model);
|
||||||
|
} else {
|
||||||
|
localRest.push(model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
localRest = local;
|
||||||
|
}
|
||||||
|
|
||||||
const peersByProvider = peerModels.reduce(
|
const peersByProvider = peerModels.reduce(
|
||||||
(acc, model) => {
|
(acc, model) => {
|
||||||
const peerId = model.peerID || "unknown";
|
const peerId = model.peerID || "unknown";
|
||||||
@@ -20,5 +46,5 @@ export function groupModels(models: Model[]): GroupedModels {
|
|||||||
{} as Record<string, Model[]>
|
{} as Record<string, Model[]>
|
||||||
);
|
);
|
||||||
|
|
||||||
return { local, peersByProvider };
|
return { local: localRest, localMatching, peersByProvider };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,16 @@ export type ConnectionState = "connected" | "connecting" | "disconnected";
|
|||||||
|
|
||||||
export type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
export type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||||
|
|
||||||
|
export interface ModelCapabilities {
|
||||||
|
vision?: boolean;
|
||||||
|
audio_transcriptions?: boolean;
|
||||||
|
audio_speech?: boolean;
|
||||||
|
image_generation?: boolean;
|
||||||
|
image_to_image?: boolean;
|
||||||
|
function_calling?: boolean;
|
||||||
|
reranker?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
export interface Model {
|
export interface Model {
|
||||||
id: string;
|
id: string;
|
||||||
state: ModelStatus;
|
state: ModelStatus;
|
||||||
@@ -10,10 +20,13 @@ export interface Model {
|
|||||||
unlisted: boolean;
|
unlisted: boolean;
|
||||||
peerID: string;
|
peerID: string;
|
||||||
aliases?: string[];
|
aliases?: string[];
|
||||||
|
capabilities?: ModelCapabilities;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TokenMetrics {
|
export interface TokenMetrics {
|
||||||
cache_tokens: number;
|
cache_tokens: number;
|
||||||
|
draft_tokens: number;
|
||||||
|
draft_acc_tokens: number;
|
||||||
input_tokens: number;
|
input_tokens: number;
|
||||||
output_tokens: number;
|
output_tokens: number;
|
||||||
prompt_per_second: number;
|
prompt_per_second: number;
|
||||||
@@ -30,6 +43,8 @@ export interface ActivityLogEntry {
|
|||||||
tokens: TokenMetrics;
|
tokens: TokenMetrics;
|
||||||
duration_ms: number;
|
duration_ms: number;
|
||||||
has_capture: boolean;
|
has_capture: boolean;
|
||||||
|
error_msg?: string;
|
||||||
|
metadata?: Record<string, string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ReqRespCapture {
|
export interface ReqRespCapture {
|
||||||
|
|||||||
@@ -2,25 +2,13 @@
|
|||||||
import { metrics, getCapture } from "../stores/api";
|
import { metrics, getCapture } from "../stores/api";
|
||||||
import ActivityStats from "../components/ActivityStats.svelte";
|
import ActivityStats from "../components/ActivityStats.svelte";
|
||||||
import Tooltip from "../components/Tooltip.svelte";
|
import Tooltip from "../components/Tooltip.svelte";
|
||||||
|
import MetadataTooltip from "../components/MetadataTooltip.svelte";
|
||||||
import CaptureDialog from "../components/CaptureDialog.svelte";
|
import CaptureDialog from "../components/CaptureDialog.svelte";
|
||||||
import { persistentStore } from "../stores/persistent";
|
import { persistentStore } from "../stores/persistent";
|
||||||
import { onMount } from "svelte";
|
import { onMount } from "svelte";
|
||||||
import type { ReqRespCapture } from "../lib/types";
|
import type { ReqRespCapture } from "../lib/types";
|
||||||
|
|
||||||
type ColumnKey =
|
type ColumnKey = string;
|
||||||
| "id"
|
|
||||||
| "time"
|
|
||||||
| "model"
|
|
||||||
| "req_path"
|
|
||||||
| "resp_status_code"
|
|
||||||
| "resp_content_type"
|
|
||||||
| "cached"
|
|
||||||
| "prompt"
|
|
||||||
| "generated"
|
|
||||||
| "prompt_speed"
|
|
||||||
| "gen_speed"
|
|
||||||
| "duration"
|
|
||||||
| "capture";
|
|
||||||
|
|
||||||
interface ColumnDef {
|
interface ColumnDef {
|
||||||
key: ColumnKey;
|
key: ColumnKey;
|
||||||
@@ -33,26 +21,31 @@
|
|||||||
{ key: "time", label: "Time", defaultVisible: true },
|
{ key: "time", label: "Time", defaultVisible: true },
|
||||||
{ key: "model", label: "Model", defaultVisible: true },
|
{ key: "model", label: "Model", defaultVisible: true },
|
||||||
{ key: "req_path", label: "Path", defaultVisible: false },
|
{ key: "req_path", label: "Path", defaultVisible: false },
|
||||||
{ key: "resp_status_code", label: "Status", defaultVisible: false },
|
{ key: "resp_status_code", label: "Status", defaultVisible: true },
|
||||||
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
||||||
{ key: "cached", label: "Cached", defaultVisible: true },
|
{ key: "cached", label: "Cached", defaultVisible: true },
|
||||||
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
||||||
{ key: "generated", label: "Generated", defaultVisible: true },
|
{ key: "generated", label: "Generated", defaultVisible: true },
|
||||||
|
{ key: "drafted", label: "Drafted", defaultVisible: false },
|
||||||
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
||||||
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
||||||
{ key: "duration", label: "Duration", defaultVisible: true },
|
{ key: "duration", label: "Duration", defaultVisible: true },
|
||||||
{ key: "capture", label: "Capture", defaultVisible: true },
|
{ key: "capture", label: "Capture", defaultVisible: true },
|
||||||
|
{ key: "meta", label: "Meta", defaultVisible: false },
|
||||||
];
|
];
|
||||||
|
|
||||||
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
|
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
|
||||||
|
|
||||||
const visibleColumns = persistentStore<ColumnKey[]>(
|
const visibleColumns = persistentStore<ColumnKey[]>("activity-columns", defaultVisibleKeys);
|
||||||
"activity-columns",
|
const columnOrder = persistentStore<ColumnKey[]>(
|
||||||
defaultVisibleKeys
|
"activity-column-order",
|
||||||
|
columns.map((c) => c.key)
|
||||||
);
|
);
|
||||||
|
|
||||||
let columnsMenuOpen = $state(false);
|
let columnsMenuOpen = $state(false);
|
||||||
let dropdownContainer: HTMLDivElement | null = null;
|
let dropdownContainer: HTMLDivElement | null = null;
|
||||||
|
let dragKey: ColumnKey | null = $state(null);
|
||||||
|
let dragOverKey: ColumnKey | null = $state(null);
|
||||||
|
|
||||||
onMount(() => {
|
onMount(() => {
|
||||||
function handleKeydown(e: KeyboardEvent) {
|
function handleKeydown(e: KeyboardEvent) {
|
||||||
@@ -84,10 +77,92 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function isColumnVisible(key: ColumnKey): boolean {
|
||||||
|
return $visibleColumns.includes(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleDragStart(e: DragEvent, key: ColumnKey) {
|
||||||
|
dragKey = key;
|
||||||
|
e.dataTransfer?.setData("text/plain", key);
|
||||||
|
if (e.dataTransfer) {
|
||||||
|
e.dataTransfer.effectAllowed = "move";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleDragOver(e: DragEvent, key: ColumnKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
if (e.dataTransfer) {
|
||||||
|
e.dataTransfer.dropEffect = "move";
|
||||||
|
}
|
||||||
|
dragOverKey = key;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Drop handler for a row in the column menu: moves the column currently being
 * dragged (`dragKey`) into the slot of the row it was dropped on and writes the
 * resulting order to the `columnOrder` store.
 *
 * Does nothing when no drag is in progress, when the column is dropped on
 * itself, or when either key is missing from the current order.
 *
 * @param e - The drop event; its default action is always prevented.
 * @param targetKey - Key of the column row that received the drop.
 */
function handleDrop(e: DragEvent, targetKey: ColumnKey) {
  e.preventDefault();
  if (!dragKey || dragKey === targetKey) return;

  const order = [...$columnOrder];
  const fromIndex = order.indexOf(dragKey);
  const toIndex = order.indexOf(targetKey);
  if (fromIndex === -1 || toIndex === -1) return;

  // Remove the dragged key, then insert it at the target's ORIGINAL index.
  // Dragging downwards therefore lands the column just after the target and
  // dragging upwards lands it just before. Compensating the index for the
  // removal would always insert before the target, making a drop on the next
  // row a no-op and the last position unreachable.
  order.splice(fromIndex, 1);
  order.splice(toIndex, 0, dragKey);
  columnOrder.set(order);
}
|
||||||
|
|
||||||
|
function handleDragEnd() {
|
||||||
|
dragKey = null;
|
||||||
|
dragOverKey = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Return a sorted copy of `cols` following the key order stored in
 * `$columnOrder`. Columns whose key is not in the stored order sort after all
 * known ones and compare equal to each other. The input array is not mutated.
 */
function sortByColumnOrder(cols: ColumnDef[]): ColumnDef[] {
  // Build the key -> position lookup once instead of scanning the order array
  // twice per comparison. The first occurrence wins, matching indexOf().
  const position = new Map<ColumnKey, number>();
  $columnOrder.forEach((key, index) => {
    if (!position.has(key)) position.set(key, index);
  });

  return cols.slice().sort((a, b) => {
    const aIndex = position.get(a.key);
    const bIndex = position.get(b.key);
    if (aIndex === undefined && bIndex === undefined) return 0;
    if (aIndex === undefined) return 1;
    if (bIndex === undefined) return -1;
    return aIndex - bIndex;
  });
}

// Every column definition, in the user's stored order (drives the column menu).
let orderedColumns = $derived(sortByColumnOrder(columns));

// Keys of the currently visible columns, in the user's stored order
// (drives the table header and the cells of each row).
let activeVisibleColumns = $derived(
  sortByColumnOrder(columns.filter((c) => isColumnVisible(c.key))).map((c) => c.key)
);
|
||||||
|
|
||||||
|
let columnLabelMap = $derived(Object.fromEntries(columns.map((c) => [c.key, c.label])));
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
const staticKeys = new Set(columns.map((c) => c.key));
|
||||||
|
const order = $columnOrder;
|
||||||
|
const hasStale = order.some((k) => !staticKeys.has(k));
|
||||||
|
const missing = columns.filter((c) => !order.includes(c.key)).map((c) => c.key);
|
||||||
|
if (hasStale || missing.length > 0) {
|
||||||
|
const cleaned = order.filter((k) => staticKeys.has(k));
|
||||||
|
columnOrder.set([...cleaned, ...missing]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
function formatSpeed(speed: number): string {
|
function formatSpeed(speed: number): string {
|
||||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Render draft-token statistics as an acceptance rate followed by the raw
 * counts, e.g. `"50.0% (5/10)"`.
 *
 * @param drafted - Number of drafted tokens.
 * @param accepted - Number of drafted tokens that were accepted.
 * @returns The formatted rate, or `"-"` when `drafted` is not greater than zero.
 */
function formatDrafted(drafted: number, accepted: number): string {
  // Negated comparison on purpose: anything that is not > 0 yields the placeholder.
  if (!(drafted > 0)) {
    return "-";
  }
  const percent = ((accepted * 100) / drafted).toFixed(1);
  return `${percent}% (${accepted}/${drafted})`;
}
|
||||||
|
|
||||||
function formatDuration(ms: number): string {
|
function formatDuration(ms: number): string {
|
||||||
return (ms / 1000).toFixed(2) + "s";
|
return (ms / 1000).toFixed(2) + "s";
|
||||||
}
|
}
|
||||||
@@ -157,22 +232,37 @@
|
|||||||
</svg>
|
</svg>
|
||||||
</button>
|
</button>
|
||||||
{#if columnsMenuOpen}
|
{#if columnsMenuOpen}
|
||||||
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]">
|
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]" role="list">
|
||||||
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10">
|
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10" role="presentation">
|
||||||
Columns
|
Columns
|
||||||
</div>
|
</div>
|
||||||
{#each columns as col (col.key)}
|
{#each orderedColumns as col (col.key)}
|
||||||
<label
|
{@const key = col.key}
|
||||||
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors"
|
<div
|
||||||
|
class="flex items-center gap-2 px-3 py-1.5 text-sm hover:bg-secondary-hover transition-colors {dragOverKey === key && dragKey !== key ? 'bg-primary/10 ring-1 ring-primary/40' : ''} {dragKey === key ? 'opacity-40' : ''}"
|
||||||
|
role="listitem"
|
||||||
|
ondragover={(e) => handleDragOver(e, key)}
|
||||||
|
ondrop={(e) => handleDrop(e, key)}
|
||||||
>
|
>
|
||||||
<input
|
<span
|
||||||
type="checkbox"
|
class="text-txtsecondary select-none cursor-grab"
|
||||||
checked={$visibleColumns.includes(col.key)}
|
draggable={true}
|
||||||
onchange={() => toggleColumn(col.key)}
|
role="button"
|
||||||
class="rounded"
|
tabindex="-1"
|
||||||
/>
|
aria-label="Drag to reorder {col.label}"
|
||||||
{col.label}
|
ondragstart={(e) => handleDragStart(e, key)}
|
||||||
</label>
|
ondragend={handleDragEnd}
|
||||||
|
>⋮⋮</span>
|
||||||
|
<label class="flex items-center gap-2 flex-1 cursor-pointer">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={isColumnVisible(key)}
|
||||||
|
onchange={() => toggleColumn(key)}
|
||||||
|
class="rounded"
|
||||||
|
/>
|
||||||
|
{col.label}
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
{/each}
|
{/each}
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
@@ -182,112 +272,90 @@
|
|||||||
<table class="min-w-full divide-y">
|
<table class="min-w-full divide-y">
|
||||||
<thead class="border-gray-200 dark:border-white/10">
|
<thead class="border-gray-200 dark:border-white/10">
|
||||||
<tr class="text-left text-xs uppercase tracking-wider">
|
<tr class="text-left text-xs uppercase tracking-wider">
|
||||||
{#if $visibleColumns.includes("id")}
|
{#each activeVisibleColumns as key (key)}
|
||||||
<th class="px-6 py-3">ID</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("time")}
|
|
||||||
<th class="px-6 py-3">Time</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("model")}
|
|
||||||
<th class="px-6 py-3">Model</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("req_path")}
|
|
||||||
<th class="px-6 py-3">Path</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("resp_status_code")}
|
|
||||||
<th class="px-6 py-3">Status</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("resp_content_type")}
|
|
||||||
<th class="px-6 py-3">Content-Type</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("cached")}
|
|
||||||
<th class="px-6 py-3">
|
<th class="px-6 py-3">
|
||||||
Cached <Tooltip content="prompt tokens from cache" />
|
{#if key === "cached"}
|
||||||
|
Cached <Tooltip content="prompt tokens from cache" />
|
||||||
|
{:else if key === "prompt"}
|
||||||
|
Prompt <Tooltip content="new prompt tokens processed" />
|
||||||
|
{:else if key === "drafted"}
|
||||||
|
Drafted <Tooltip content="acceptance rate (accepted/drafted)" />
|
||||||
|
{:else}
|
||||||
|
{columnLabelMap[key] ?? key}
|
||||||
|
{/if}
|
||||||
</th>
|
</th>
|
||||||
{/if}
|
{/each}
|
||||||
{#if $visibleColumns.includes("prompt")}
|
|
||||||
<th class="px-6 py-3">
|
|
||||||
Prompt <Tooltip content="new prompt tokens processed" />
|
|
||||||
</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("generated")}
|
|
||||||
<th class="px-6 py-3">Generated</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("prompt_speed")}
|
|
||||||
<th class="px-6 py-3">Prompt Speed</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("gen_speed")}
|
|
||||||
<th class="px-6 py-3">Gen Speed</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("duration")}
|
|
||||||
<th class="px-6 py-3">Duration</th>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("capture")}
|
|
||||||
<th class="px-6 py-3">Capture</th>
|
|
||||||
{/if}
|
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody class="divide-y">
|
<tbody class="divide-y">
|
||||||
{#if sortedMetrics.length === 0}
|
{#if sortedMetrics.length === 0}
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
<td colspan={activeVisibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||||
No activity recorded
|
No activity recorded
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{:else}
|
{:else}
|
||||||
{#each sortedMetrics as metric (metric.id)}
|
{#each sortedMetrics as metric (metric.id)}
|
||||||
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||||
{#if $visibleColumns.includes("id")}
|
{#each activeVisibleColumns as key (key)}
|
||||||
<td class="px-4 py-4">{metric.id + 1}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("time")}
|
|
||||||
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("model")}
|
|
||||||
<td class="px-6 py-4">{metric.model}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("req_path")}
|
|
||||||
<td class="px-6 py-4">{metric.req_path || "-"}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("resp_status_code")}
|
|
||||||
<td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("resp_content_type")}
|
|
||||||
<td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("cached")}
|
|
||||||
<td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("prompt")}
|
|
||||||
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("generated")}
|
|
||||||
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("prompt_speed")}
|
|
||||||
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("gen_speed")}
|
|
||||||
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("duration")}
|
|
||||||
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
|
||||||
{/if}
|
|
||||||
{#if $visibleColumns.includes("capture")}
|
|
||||||
<td class="px-6 py-4">
|
<td class="px-6 py-4">
|
||||||
{#if metric.has_capture}
|
{#if key === "id"}
|
||||||
<button
|
{metric.id + 1}
|
||||||
onclick={() => viewCapture(metric.id)}
|
{:else if key === "time"}
|
||||||
disabled={loadingCaptureId === metric.id}
|
{formatRelativeTime(metric.timestamp)}
|
||||||
class="btn btn--sm"
|
{:else if key === "model"}
|
||||||
>
|
{metric.model}
|
||||||
{loadingCaptureId === metric.id ? "..." : "View"}
|
{:else if key === "req_path"}
|
||||||
</button>
|
{metric.req_path || "-"}
|
||||||
|
{:else if key === "resp_status_code"}
|
||||||
|
{#if metric.error_msg}
|
||||||
|
<span class="text-red-500 dark:text-red-400 cursor-help" title={metric.error_msg}>
|
||||||
|
{metric.resp_status_code || "-"}
|
||||||
|
</span>
|
||||||
|
{:else}
|
||||||
|
{metric.resp_status_code || "-"}
|
||||||
|
{/if}
|
||||||
|
{:else if key === "resp_content_type"}
|
||||||
|
{metric.resp_content_type || "-"}
|
||||||
|
{:else if key === "cached"}
|
||||||
|
{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}
|
||||||
|
{:else if key === "prompt"}
|
||||||
|
{metric.tokens.input_tokens.toLocaleString()}
|
||||||
|
{:else if key === "generated"}
|
||||||
|
{metric.tokens.output_tokens.toLocaleString()}
|
||||||
|
{:else if key === "drafted"}
|
||||||
|
{formatDrafted(metric.tokens.draft_tokens, metric.tokens.draft_acc_tokens)}
|
||||||
|
{:else if key === "prompt_speed"}
|
||||||
|
{formatSpeed(metric.tokens.prompt_per_second)}
|
||||||
|
{:else if key === "gen_speed"}
|
||||||
|
{formatSpeed(metric.tokens.tokens_per_second)}
|
||||||
|
{:else if key === "duration"}
|
||||||
|
{formatDuration(metric.duration_ms)}
|
||||||
|
{:else if key === "capture"}
|
||||||
|
{#if metric.has_capture}
|
||||||
|
<button
|
||||||
|
onclick={() => viewCapture(metric.id)}
|
||||||
|
disabled={loadingCaptureId === metric.id}
|
||||||
|
class="btn btn--sm"
|
||||||
|
>
|
||||||
|
{loadingCaptureId === metric.id ? "..." : "View"}
|
||||||
|
</button>
|
||||||
|
{:else}
|
||||||
|
<span class="text-txtsecondary">-</span>
|
||||||
|
{/if}
|
||||||
|
{:else if key === "meta"}
|
||||||
|
{#if Object.keys(metric.metadata || {}).length > 0}
|
||||||
|
<MetadataTooltip metadata={metric.metadata}>
|
||||||
|
<span class="cursor-help text-txtsecondary hover:text-txtmain">...</span>
|
||||||
|
</MetadataTooltip>
|
||||||
|
{:else}
|
||||||
|
<span class="text-txtsecondary">-</span>
|
||||||
|
{/if}
|
||||||
{:else}
|
{:else}
|
||||||
<span class="text-txtsecondary">-</span>
|
-
|
||||||
{/if}
|
{/if}
|
||||||
</td>
|
</td>
|
||||||
{/if}
|
{/each}
|
||||||
</tr>
|
</tr>
|
||||||
{/each}
|
{/each}
|
||||||
{/if}
|
{/if}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ export const proxyLogs = writable<string>("");
|
|||||||
export const upstreamLogs = writable<string>("");
|
export const upstreamLogs = writable<string>("");
|
||||||
export const metrics = writable<ActivityLogEntry[]>([]);
|
export const metrics = writable<ActivityLogEntry[]>([]);
|
||||||
export const inFlightRequests = writable<number>(0);
|
export const inFlightRequests = writable<number>(0);
|
||||||
|
export const performanceEnabled = writable<boolean>(false);
|
||||||
export const versionInfo = writable<VersionInfo>({
|
export const versionInfo = writable<VersionInfo>({
|
||||||
build_date: "unknown",
|
build_date: "unknown",
|
||||||
commit: "unknown",
|
commit: "unknown",
|
||||||
@@ -176,15 +177,19 @@ export async function unloadSingleModel(model: string): Promise<void> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function loadModel(model: string): Promise<void> {
|
export async function loadModel(model: string, signal?: AbortSignal): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`/upstream/${model}/`, {
|
const response = await fetch(`/upstream/${model}/?_=${Date.now()}`, {
|
||||||
method: "GET",
|
method: "GET",
|
||||||
|
signal,
|
||||||
});
|
});
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`Failed to load model: ${response.status}`);
|
throw new Error(`Failed to load model: ${response.status}`);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof DOMException && error.name === "AbortError") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
console.error("Failed to load model:", error);
|
console.error("Failed to load model:", error);
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
@@ -206,6 +211,20 @@ export async function getCapture(id: number): Promise<ReqRespCapture | null> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Probe `GET /api/performance` and record the outcome in the
 * `performanceEnabled` store.
 *
 * The store is set to `true` only when the request succeeds and the JSON body
 * is an object whose `enabled` field is exactly `true`. A non-OK status, a
 * network or parse failure, or any other payload shape sets it to `false`.
 * This function never rejects.
 */
export async function checkPerformanceEnabled(): Promise<void> {
  try {
    const response = await fetch("/api/performance");
    if (!response.ok) {
      performanceEnabled.set(false);
      return;
    }

    // The body is untrusted JSON: validate it instead of forwarding whatever
    // `enabled` happens to be (undefined, a string, ...) into a boolean store.
    const data: unknown = await response.json();
    const enabled =
      typeof data === "object" && data !== null && (data as { enabled?: unknown }).enabled === true;
    performanceEnabled.set(enabled);
  } catch {
    performanceEnabled.set(false);
  }
}
|
||||||
|
|
||||||
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
||||||
try {
|
try {
|
||||||
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
||||||
|
|||||||
Reference in New Issue
Block a user