Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0292c90ca1 | |||
| 617c7dc6b9 | |||
| 542b79dacf | |||
| 0a25b3bd31 | |||
| 32bc781326 | |||
| 316ad63f76 | |||
| e37077a963 | |||
| eff9b60434 | |||
| 9bcddad91b | |||
| a15e47922c | |||
| 0ab214d1c8 | |||
| d07b063ab6 | |||
| 826210dac9 | |||
| 6cf1317341 | |||
| 8e84b2ec4f | |||
| ed77385d08 | |||
| 92b90447e8 | |||
| 62aea0e83d | |||
| 8c660dcb90 | |||
| f6877b8175 | |||
| 9b3a33d7b9 | |||
| 0cfe5a6639 | |||
| 44e1501e81 | |||
| 46cea36bc2 | |||
| ccfba0df28 | |||
| ddfae90b19 | |||
| 29d3d9ba20 | |||
| 9be9a87fa0 | |||
| 6ea551362e | |||
| 03d58e53fa | |||
| c790d0ee03 | |||
| 4ca9c478a2 | |||
| 146a9eab24 | |||
| 02e015fa49 | |||
| 63bc266395 | |||
| 636b53e70f | |||
| 59cd3b690d | |||
| 5d1e62d224 | |||
| dbb869d019 | |||
| 26bb17e57e | |||
| 2982dd3d40 | |||
| 79dc87f881 | |||
| b2fcc2daa1 | |||
| 6a9c4efc8f | |||
| 0c813e44d1 | |||
| fe71e8a6ea |
+3
-1
@@ -13,8 +13,10 @@ reviews:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
enabled: false
|
||||
drafts: false
|
||||
unit_tests:
|
||||
enabled: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
name: Build CUDA image (fork)
|
||||
|
||||
# Builds this fork's llama-swap (serial scheduler + embedded UI) from source and
|
||||
# layers it on a pinned llama.cpp CUDA server base, then pushes to the Gitea
|
||||
# container registry, e.g. gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||
#
|
||||
# Requires repo secrets: REGISTRY_USER, REGISTRY_PASSWORD (push to the registry).
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llama_swap_version:
|
||||
description: "llama-swap version label (image tag prefix)"
|
||||
required: false
|
||||
default: "v230"
|
||||
llamacpp_build:
|
||||
description: "llama.cpp CUDA server build (base image tag suffix)"
|
||||
required: false
|
||||
default: "b9821"
|
||||
# Building the build definition itself kicks off a fresh image.
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- ".gitea/workflows/build-cuda-image.yml"
|
||||
- "docker/fork-cuda.Containerfile"
|
||||
|
||||
env:
|
||||
REGISTRY: gitea.stevedudenhoeffer.com
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Compute image metadata
|
||||
id: meta
|
||||
run: |
|
||||
LS_VER="${{ inputs.llama_swap_version || 'v230' }}"
|
||||
LCPP="${{ inputs.llamacpp_build || 'b9821' }}"
|
||||
{
|
||||
echo "image=${REGISTRY}/${{ github.repository }}"
|
||||
echo "tag=${LS_VER}-cuda-${LCPP}"
|
||||
echo "base_tag=server-cuda-${LCPP}"
|
||||
echo "ls_version=${LS_VER}"
|
||||
echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Gitea registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ secrets.REGISTRY_USER }}
|
||||
password: ${{ secrets.REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: docker/fork-cuda.Containerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
BASE_TAG=${{ steps.meta.outputs.base_tag }}
|
||||
LS_VERSION=${{ steps.meta.outputs.ls_version }}
|
||||
GIT_HASH=${{ github.sha }}
|
||||
BUILD_DATE=${{ steps.meta.outputs.build_date }}
|
||||
tags: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}
|
||||
|
||||
- name: Summary
|
||||
run: |
|
||||
echo "Pushed ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}" >> "$GITHUB_STEP_SUMMARY"
|
||||
@@ -13,11 +13,11 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f #v10.2.0
|
||||
with:
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 14
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 30
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||
stale-issue-message: "This issue is stale because it has been open without activity for 30 days. Please remove the stale label if this was an error."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 30 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -44,13 +44,10 @@ jobs:
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 #v6.2.0
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install check-jsonschema
|
||||
run: pip install check-jsonschema
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Validate config.example.yaml against schema
|
||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||
run: go test ./internal/config/ -run TestConfig_ExampleMatchesSchema -v
|
||||
|
||||
@@ -2,10 +2,10 @@ name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# the llama.cpp daily packages have time to build and publish (~8hr after llama.cpp project's cron)
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
- cron: "00 12,18 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -32,11 +32,9 @@ jobs:
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
|
||||
with:
|
||||
node-version: "24"
|
||||
- name: Install dependencies and build UI
|
||||
- name: Build UI
|
||||
run: |
|
||||
cd ui-svelte
|
||||
npm ci
|
||||
npm run build
|
||||
make ui
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@1a80836c5c9d9e5755a25cb59ec6f45a3b5f41a8 #7.2.1
|
||||
|
||||
@@ -5,3 +5,6 @@ dist/
|
||||
.vscode
|
||||
.DS_Store
|
||||
.dev/
|
||||
|
||||
# UI build output; placeholder.txt is kept so the go:embed succeeds.
|
||||
internal/server/ui_dist/*
|
||||
|
||||
@@ -5,23 +5,22 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
- typescript, vite and svelte 5 for UI (located in ui-svelte/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- keep them short and focused on changes
|
||||
- skip the test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
||||
- Run `gofmt -w <file>` before committing to fix any formatting
|
||||
- Build go binaries into the ./build/ subdirectory
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
- Use `make test-ui` after making changes to the UI in ui-svelte/
|
||||
@@ -29,7 +28,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
internal/server: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
|
||||
@@ -19,21 +19,17 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/... ./internal/...
|
||||
staticcheck ./proxy/... ./internal/... || true
|
||||
test-dev:
|
||||
go test -short ./...
|
||||
staticcheck ./... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/... ./internal/...
|
||||
test:
|
||||
go test -short -count=1 ./internal/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/... ./internal/...
|
||||
test-all:
|
||||
go test -race -count=1 ./internal/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui-svelte && npm install
|
||||
@@ -41,6 +37,7 @@ ui/node_modules:
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui-svelte && npm run build
|
||||
touch internal/server/ui_dist/placeholder.txt
|
||||
|
||||
# Build OSX binary
|
||||
mac: ui
|
||||
@@ -63,7 +60,7 @@ windows: ui
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
# for testing with real external processes
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||
|
||||
@@ -88,10 +88,11 @@ Real time log streaming:
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. WinGet
|
||||
4. From release binaries
|
||||
5. From source
|
||||
2. Homebrew (macOS and Linux)
|
||||
3. MacPorts (macOS)
|
||||
4. WinGet
|
||||
5. From release binaries
|
||||
6. From source
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
@@ -155,6 +156,16 @@ brew install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
### MacPorts (macOS)
|
||||
|
||||
> [!NOTE]
|
||||
> Maintained by MacPorts community - [llama-swap port](https://ports.macports.org/port/llama-swap). It is not an official part of llama-swap.
|
||||
|
||||
```shell
|
||||
sudo port install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
### WinGet Install (Windows)
|
||||
|
||||
> [!NOTE]
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var loremWords = strings.Fields(
|
||||
"Lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor " +
|
||||
"incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud " +
|
||||
"exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat Duis aute " +
|
||||
"irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla " +
|
||||
"pariatur Excepteur sint occaecat cupidatat non proident sunt in culpa qui officia " +
|
||||
"deserunt mollit anim id est laborum Sed ut perspiciatis unde omnis iste natus error " +
|
||||
"sit voluptatem accusantium doloremque laudantium totam rem aperiam eaque ipsa quae " +
|
||||
"ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo " +
|
||||
"Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit",
|
||||
)
|
||||
|
||||
var (
|
||||
flagListen = flag.String("listen", "localhost:9898", "listen address")
|
||||
flagTokens = flag.Int("tokens", 1000, "number of tokens to return")
|
||||
flagTPS = flag.Float64("tps", 75, "tokens per second")
|
||||
flagLoad = flag.String("load", "0s", "simulated load duration (e.g. 2s, 500ms)")
|
||||
)
|
||||
|
||||
type chunkDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type chunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta chunkDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type chatChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []chunkChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type completionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type completionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message completionMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type completionUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type chatCompletion struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []completionChoice `json:"choices"`
|
||||
Usage completionUsage `json:"usage"`
|
||||
}
|
||||
|
||||
func loremText(n int) string {
|
||||
words := make([]string, n)
|
||||
for i := range words {
|
||||
words[i] = loremWords[i%len(loremWords)]
|
||||
}
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
||||
func sendChunk(w http.ResponseWriter, content string, finishReason *string) error {
|
||||
chunk := chatChunk{
|
||||
ID: "chatcmpl-fake",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "fake-model",
|
||||
Choices: []chunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: chunkDelta{Content: content},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return err
|
||||
}
|
||||
|
||||
// startLoading runs the countdown log and closes ready when loadDur elapses.
|
||||
// If loadDur is zero, ready is closed immediately.
|
||||
func startLoading(loadDur time.Duration) <-chan struct{} {
|
||||
ready := make(chan struct{})
|
||||
if loadDur == 0 {
|
||||
close(ready)
|
||||
return ready
|
||||
}
|
||||
go func() {
|
||||
deadline := time.Now().Add(loadDur)
|
||||
log.Printf("loading... %s remaining", loadDur.Round(time.Second))
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
timer := time.NewTimer(loadDur)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
close(ready)
|
||||
log.Printf("ready")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if rem := time.Until(deadline).Round(time.Second); rem > 0 {
|
||||
log.Printf("loading... %s remaining", rem)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ready
|
||||
}
|
||||
|
||||
func healthHandler(ready <-chan struct{}) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ready:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chatHandler(ready <-chan struct{}) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
streaming := gjson.GetBytes(body, "stream").Bool()
|
||||
ctx := r.Context()
|
||||
|
||||
select {
|
||||
case <-ready:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
tokens := *flagTokens
|
||||
tps := *flagTPS
|
||||
if tps <= 0 {
|
||||
tps = 1
|
||||
}
|
||||
|
||||
if !streaming {
|
||||
delay := time.Duration(float64(tokens) / tps * float64(time.Second))
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
text := loremText(tokens)
|
||||
resp := chatCompletion{
|
||||
ID: "chatcmpl-fake",
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "fake-model",
|
||||
Choices: []completionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: completionMessage{Role: "assistant", Content: text},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: completionUsage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: tokens,
|
||||
TotalTokens: tokens,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Send role delta first
|
||||
first := chatChunk{
|
||||
ID: "chatcmpl-fake",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "fake-model",
|
||||
Choices: []chunkChoice{
|
||||
{Index: 0, Delta: chunkDelta{Role: "assistant"}},
|
||||
},
|
||||
}
|
||||
if data, err := json.Marshal(first); err == nil {
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
interval := time.Duration(float64(time.Second) / tps)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
stop := "stop"
|
||||
for i := 0; i < tokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
word := loremWords[i%len(loremWords)]
|
||||
if i < tokens-1 {
|
||||
if err := sendChunk(w, word+" ", nil); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := sendChunk(w, word, &stop); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
loadDur, err := time.ParseDuration(*flagLoad)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid -load value %q: %v", *flagLoad, err)
|
||||
}
|
||||
|
||||
ready := startLoading(loadDur)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", healthHandler(ready))
|
||||
mux.HandleFunc("/v1/chat/completions", chatHandler(ready))
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: *flagListen,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("listening on %s (tokens=%d tps=%.1f load=%s)",
|
||||
*flagListen, *flagTokens, *flagTPS, loadDur)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("shutting down...")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Printf("shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
func printSysStat(s perf.SysStat) {
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func main() {
|
||||
prompt := flag.String("prompt", "Write a few sentences about the history of computing.", "user message sent to each model")
|
||||
maxTokens := flag.Int("max-tokens", 256, "max_tokens per request")
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <base-url> <model> [model...]\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Example: %s -max-tokens 400 http://localhost:8080 A B C D\n\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) < 2 {
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
baseURL := args[0]
|
||||
models := args[1:]
|
||||
|
||||
m := newModel(models)
|
||||
prog := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
|
||||
|
||||
// Chain of triggers ensures requests are sent in the order provided.
|
||||
triggers := make([]chan struct{}, len(models))
|
||||
for i := range triggers {
|
||||
triggers[i] = make(chan struct{}, 1)
|
||||
}
|
||||
triggers[0] <- struct{}{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
start := time.Now()
|
||||
|
||||
for i, name := range models {
|
||||
wg.Add(1)
|
||||
go func(idx int, mdl string) {
|
||||
defer wg.Done()
|
||||
|
||||
<-triggers[idx]
|
||||
|
||||
reqStart := time.Now()
|
||||
prog.Send(statusMsg{idx: idx, status: statusStreaming})
|
||||
|
||||
if idx+1 < len(triggers) {
|
||||
triggers[idx+1] <- struct{}{}
|
||||
}
|
||||
|
||||
err := sendRequest(baseURL, mdl, *prompt, *maxTokens, idx, func(i int, text string) {
|
||||
prog.Send(deltaMsg{idx: i, text: text})
|
||||
})
|
||||
|
||||
elapsed := time.Since(reqStart)
|
||||
if err != nil {
|
||||
prog.Send(statusMsg{idx: idx, status: statusError, elapsed: elapsed, err: err})
|
||||
} else {
|
||||
prog.Send(statusMsg{idx: idx, status: statusDone, elapsed: elapsed})
|
||||
}
|
||||
}(i, name)
|
||||
}
|
||||
|
||||
if _, err := prog.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
printSummary(m, start)
|
||||
}
|
||||
|
||||
func printSummary(m *model, start time.Time) {
|
||||
fmt.Println("Summary:")
|
||||
for _, p := range m.panels {
|
||||
switch p.status {
|
||||
case statusError:
|
||||
fmt.Printf(" [%d] %-20s ERROR elapsed=%s err=%v\n",
|
||||
p.idx, p.model, p.elapsed.Round(time.Millisecond), p.err)
|
||||
case statusDone:
|
||||
fmt.Printf(" [%d] %-20s done elapsed=%s\n",
|
||||
p.idx, p.model, p.elapsed.Round(time.Millisecond))
|
||||
default:
|
||||
fmt.Printf(" [%d] %-20s %s\n", p.idx, p.model, p.status)
|
||||
}
|
||||
}
|
||||
fmt.Printf("all done in %s\n", time.Since(start).Round(time.Millisecond))
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// deltaSink receives streamed text fragments for a given model panel.
|
||||
type deltaSink func(idx int, text string)
|
||||
|
||||
type streamDelta struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
|
||||
type streamChoice struct {
|
||||
Delta streamDelta `json:"delta"`
|
||||
}
|
||||
|
||||
type streamChunk struct {
|
||||
Choices []streamChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// sendRequest streams a chat completion and forwards each content/reasoning
|
||||
// delta to sink. Reasoning and assistant content are emitted into the same
|
||||
// stream so they render together.
|
||||
func sendRequest(baseURL, model, prompt string, maxTokens, idx int, sink deltaSink) error {
|
||||
payload := map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
"max_tokens": maxTokens,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.Post(baseURL+"/v1/chat/completions", "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(b)))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if data == "" || data == "[DONE]" {
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
var chunk streamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, c := range chunk.Choices {
|
||||
if c.Delta.ReasoningContent != "" {
|
||||
sink(idx, c.Delta.ReasoningContent)
|
||||
}
|
||||
if c.Delta.Content != "" {
|
||||
sink(idx, c.Delta.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
@@ -0,0 +1,343 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
type panelStatus int
|
||||
|
||||
const (
|
||||
statusWaiting panelStatus = iota
|
||||
statusStreaming
|
||||
statusDone
|
||||
statusError
|
||||
)
|
||||
|
||||
func (s panelStatus) String() string {
|
||||
switch s {
|
||||
case statusStreaming:
|
||||
return "streaming"
|
||||
case statusDone:
|
||||
return "done"
|
||||
case statusError:
|
||||
return "error"
|
||||
default:
|
||||
return "waiting"
|
||||
}
|
||||
}
|
||||
|
||||
// deltaMsg appends streamed text to a panel.
|
||||
type deltaMsg struct {
|
||||
idx int
|
||||
text string
|
||||
}
|
||||
|
||||
// statusMsg updates a panel's lifecycle state.
|
||||
type statusMsg struct {
|
||||
idx int
|
||||
status panelStatus
|
||||
elapsed time.Duration
|
||||
err error
|
||||
}
|
||||
|
||||
type panel struct {
|
||||
idx int
|
||||
model string
|
||||
color lipgloss.Color
|
||||
status panelStatus
|
||||
buf strings.Builder
|
||||
elapsed time.Duration
|
||||
err error
|
||||
}
|
||||
|
||||
const (
|
||||
minPanelWidth = 28
|
||||
maxCols = 3
|
||||
panelHeight = 9 // total box height including border + header
|
||||
)
|
||||
|
||||
type model struct {
|
||||
panels []*panel
|
||||
focused int
|
||||
vp viewport.Model
|
||||
width int
|
||||
height int
|
||||
cols int
|
||||
pw int // inner panel content width
|
||||
ready bool
|
||||
}
|
||||
|
||||
func newModel(models []string) *model {
|
||||
// Assign a stable color per unique model name (by first appearance).
|
||||
colorOf := map[string]lipgloss.Color{}
|
||||
panels := make([]*panel, len(models))
|
||||
for i, m := range models {
|
||||
c, ok := colorOf[m]
|
||||
if !ok {
|
||||
c = modelPalette[len(colorOf)%len(modelPalette)]
|
||||
colorOf[m] = c
|
||||
}
|
||||
panels[i] = &panel{idx: i, model: m, color: c, status: statusWaiting}
|
||||
}
|
||||
return &model{panels: panels, focused: 0}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd { return nil }
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.relayout()
|
||||
m.refreshViewport(true)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "q", "ctrl+c", "esc":
|
||||
return m, tea.Quit
|
||||
case "tab", "right", "l":
|
||||
m.setFocus(m.focused + 1)
|
||||
return m, nil
|
||||
case "shift+tab", "left", "h":
|
||||
m.setFocus(m.focused - 1)
|
||||
return m, nil
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.vp, cmd = m.vp.Update(msg)
|
||||
return m, cmd
|
||||
|
||||
case tea.MouseMsg:
|
||||
if msg.Action == tea.MouseActionPress && msg.Button == tea.MouseButtonLeft {
|
||||
if idx, ok := m.panelAt(msg.X, msg.Y); ok {
|
||||
m.setFocus(idx)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.vp, cmd = m.vp.Update(msg)
|
||||
return m, cmd
|
||||
|
||||
case deltaMsg:
|
||||
p := m.panels[msg.idx]
|
||||
p.buf.WriteString(msg.text)
|
||||
if msg.idx == m.focused {
|
||||
atBottom := m.vp.AtBottom()
|
||||
m.refreshViewport(false)
|
||||
if atBottom {
|
||||
m.vp.GotoBottom()
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case statusMsg:
|
||||
p := m.panels[msg.idx]
|
||||
p.status = msg.status
|
||||
p.elapsed = msg.elapsed
|
||||
p.err = msg.err
|
||||
if msg.err != nil {
|
||||
errTxt := lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Render("\n" + msg.err.Error())
|
||||
p.buf.WriteString(errTxt)
|
||||
if msg.idx == m.focused {
|
||||
m.refreshViewport(false)
|
||||
m.vp.GotoBottom()
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) setFocus(idx int) {
|
||||
if len(m.panels) == 0 {
|
||||
return
|
||||
}
|
||||
if idx < 0 {
|
||||
idx = len(m.panels) - 1
|
||||
}
|
||||
if idx >= len(m.panels) {
|
||||
idx = 0
|
||||
}
|
||||
if idx == m.focused {
|
||||
return
|
||||
}
|
||||
m.focused = idx
|
||||
m.refreshViewport(true)
|
||||
}
|
||||
|
||||
// relayout recomputes grid columns and panel/viewport dimensions.
|
||||
func (m *model) relayout() {
|
||||
if m.width < minPanelWidth+4 {
|
||||
m.cols = 1
|
||||
} else {
|
||||
m.cols = m.width / (minPanelWidth + 2)
|
||||
if m.cols > maxCols {
|
||||
m.cols = maxCols
|
||||
}
|
||||
if m.cols > len(m.panels) {
|
||||
m.cols = len(m.panels)
|
||||
}
|
||||
if m.cols < 1 {
|
||||
m.cols = 1
|
||||
}
|
||||
}
|
||||
|
||||
// inner content width: total width / cols, minus borders+padding (4) and gap.
|
||||
boxOuter := m.width/m.cols - 1
|
||||
m.pw = boxOuter - 4
|
||||
if m.pw < 8 {
|
||||
m.pw = 8
|
||||
}
|
||||
|
||||
m.vp = viewport.New(m.pw, panelHeight-2)
|
||||
m.ready = true
|
||||
}
|
||||
|
||||
func (m *model) refreshViewport(reset bool) {
|
||||
if !m.ready || len(m.panels) == 0 {
|
||||
return
|
||||
}
|
||||
content := lipgloss.NewStyle().Width(m.pw).Render(m.panels[m.focused].buf.String())
|
||||
m.vp.SetContent(content)
|
||||
if reset {
|
||||
m.vp.GotoBottom()
|
||||
}
|
||||
}
|
||||
|
||||
// panelAt maps screen coordinates to a panel index based on the grid layout.
|
||||
func (m *model) panelAt(x, y int) (int, bool) {
|
||||
if m.cols == 0 {
|
||||
return 0, false
|
||||
}
|
||||
boxOuterW := m.width/m.cols + 1
|
||||
col := x / boxOuterW
|
||||
row := y / panelHeight
|
||||
idx := row*m.cols + col
|
||||
if col < m.cols && idx >= 0 && idx < len(m.panels) {
|
||||
return idx, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
if !m.ready {
|
||||
return "loading..."
|
||||
}
|
||||
|
||||
rows := []string{}
|
||||
var current []string
|
||||
for i, p := range m.panels {
|
||||
current = append(current, m.renderPanel(p, i == m.focused))
|
||||
if len(current) == m.cols {
|
||||
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
|
||||
current = nil
|
||||
}
|
||||
}
|
||||
if len(current) > 0 {
|
||||
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
|
||||
}
|
||||
|
||||
grid := lipgloss.JoinVertical(lipgloss.Left, rows...)
|
||||
footer := lipgloss.NewStyle().Faint(true).Render(
|
||||
"tab/click: focus panel • wheel/↑↓/pgup/pgdn: scroll focused • q: quit")
|
||||
return grid + "\n" + footer
|
||||
}
|
||||
|
||||
// modelPalette gives each panel a distinct, readable color for its name.
|
||||
var modelPalette = []lipgloss.Color{
|
||||
"39", // blue
|
||||
"213", // magenta
|
||||
"214", // orange
|
||||
"45", // cyan
|
||||
"141", // purple
|
||||
"203", // salmon
|
||||
"82", // lime
|
||||
"227", // light yellow
|
||||
}
|
||||
|
||||
func statusColor(s panelStatus) lipgloss.Color {
|
||||
switch s {
|
||||
case statusStreaming:
|
||||
return lipgloss.Color("220") // yellow - active
|
||||
case statusDone:
|
||||
return lipgloss.Color("42") // green - success
|
||||
case statusError:
|
||||
return lipgloss.Color("196") // red - error
|
||||
default:
|
||||
return lipgloss.Color("244") // gray - waiting
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) renderPanel(p *panel, focused bool) string {
|
||||
border := lipgloss.RoundedBorder()
|
||||
if focused {
|
||||
border = lipgloss.DoubleBorder()
|
||||
}
|
||||
style := lipgloss.NewStyle().
|
||||
Border(border).
|
||||
BorderForeground(lipgloss.Color("240"))
|
||||
|
||||
statusTxt := p.status.String()
|
||||
if p.elapsed > 0 {
|
||||
statusTxt += " " + p.elapsed.Round(time.Millisecond).String()
|
||||
}
|
||||
|
||||
// Header: model name (left, model color) + status/timer (right, status color).
|
||||
name := fmt.Sprintf("[%d] %s", p.idx, p.model)
|
||||
gap := m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
|
||||
if gap < 1 {
|
||||
name = truncate(name, m.pw-lipgloss.Width(statusTxt)-1)
|
||||
gap = m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
|
||||
}
|
||||
if gap < 1 {
|
||||
gap = 1
|
||||
}
|
||||
header := lipgloss.NewStyle().Bold(true).Foreground(p.color).Render(name) +
|
||||
strings.Repeat(" ", gap) +
|
||||
lipgloss.NewStyle().Foreground(statusColor(p.status)).Render(statusTxt)
|
||||
|
||||
var bodyLines string
|
||||
if focused {
|
||||
bodyLines = m.vp.View()
|
||||
} else {
|
||||
bodyLines = tailLines(p.buf.String(), m.pw, panelHeight-2)
|
||||
}
|
||||
|
||||
content := lipgloss.JoinVertical(lipgloss.Left, header, bodyLines)
|
||||
return style.Width(m.pw).Height(panelHeight - 2).Render(content)
|
||||
}
|
||||
|
||||
func truncate(s string, w int) string {
|
||||
if w <= 0 {
|
||||
return ""
|
||||
}
|
||||
if lipgloss.Width(s) <= w {
|
||||
return s
|
||||
}
|
||||
r := []rune(s)
|
||||
if len(r) > w {
|
||||
r = r[:w]
|
||||
}
|
||||
return string(r)
|
||||
}
|
||||
|
||||
// tailLines wraps text to width w and returns the last n lines.
|
||||
func tailLines(s string, w, n int) string {
|
||||
wrapped := lipgloss.NewStyle().Width(w).Render(s)
|
||||
lines := strings.Split(wrapped, "\n")
|
||||
if len(lines) > n {
|
||||
lines = lines[len(lines)-n:]
|
||||
}
|
||||
for len(lines) < n {
|
||||
lines = append(lines, "")
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
+227
-72
@@ -82,6 +82,78 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections."
|
||||
},
|
||||
"groupsConfig": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"matrixConfig": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
@@ -306,81 +378,68 @@
|
||||
},
|
||||
"timeouts": {
|
||||
"$ref": "#/definitions/timeouts"
|
||||
},
|
||||
"capabilities": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"in": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"uniqueItems": true,
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"audio",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"description": "List of input modalities understood by the model."
|
||||
},
|
||||
"out": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"uniqueItems": true,
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"audio",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"description": "List of output modalities generated by the model."
|
||||
},
|
||||
"tools": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model supports function calling."
|
||||
},
|
||||
"reranker": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model supports the /v1/rerank endpoint."
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Maximum token context length supported by the model."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Defines what the model accepts for input, output and other metadata. Used in v1/models to inform clients what the model can do. An empty capabilities block (all zero values) is treated as not configured."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
"$ref": "#/definitions/groupsConfig"
|
||||
},
|
||||
"matrix": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
"$ref": "#/definitions/matrixConfig"
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
@@ -512,26 +571,122 @@
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
},
|
||||
"upstream": {
|
||||
"type": "object",
|
||||
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
|
||||
"properties": {
|
||||
"ignorePaths": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [
|
||||
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
|
||||
],
|
||||
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {}
|
||||
},
|
||||
"routing": {
|
||||
"type": "object",
|
||||
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||
"properties": {
|
||||
"scheduler": {
|
||||
"type": "object",
|
||||
"description": "Scheduler configuration. Decides the order in which queued requests are serviced.",
|
||||
"properties": {
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"serial",
|
||||
"fifo"
|
||||
],
|
||||
"default": "serial",
|
||||
"description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fifo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {
|
||||
"type": "object",
|
||||
"description": "Per-model priority. Keys are model IDs, values are integers (default 0). Higher values are serviced first.",
|
||||
"additionalProperties": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"router": {
|
||||
"type": "object",
|
||||
"description": "Router configuration. Selects between the group and matrix swapping strategies.",
|
||||
"properties": {
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"group",
|
||||
"matrix"
|
||||
],
|
||||
"default": "group",
|
||||
"description": "Router to use. 'group' uses static groups, 'matrix' uses the solver-based swap matrix."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"groups": {
|
||||
"$ref": "#/definitions/groupsConfig"
|
||||
},
|
||||
"matrix": {
|
||||
"$ref": "#/definitions/matrixConfig"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"allOf": [
|
||||
{
|
||||
"if": {
|
||||
"required": ["groups"]
|
||||
"required": [
|
||||
"groups"
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["matrix"]
|
||||
"required": [
|
||||
"matrix"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"if": {
|
||||
"required": ["matrix"]
|
||||
"required": [
|
||||
"matrix"
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["groups"]
|
||||
"required": [
|
||||
"groups"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+224
-79
@@ -134,6 +134,18 @@ apiKeys:
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# upstream: controls behaviour of the /upstream passthrough endpoint
|
||||
# - optional, default: empty dictionary
|
||||
# - recommended to only use in special use cases. Leaving it as the
|
||||
# default will typically be the best experience
|
||||
upstream:
|
||||
# ignorePaths: list of RE2 compatible regular expressions
|
||||
# - default: (see below)
|
||||
# - any request to a path matching any of the regular expressions
|
||||
# will be ignored and not trigger a swap
|
||||
ignorePaths:
|
||||
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -281,7 +293,7 @@ models:
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: [ "${temp}", false, "model: ${MODEL_ID}" ]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
@@ -312,6 +324,37 @@ models:
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# capabilities: defines what the model accepts for input, output and other metadata
|
||||
# - optional; omitted or all-zero means no capabilities
|
||||
# - used in v1/models to inform clients what the model can do
|
||||
capabilities:
|
||||
# in: list of modalities understood by the model
|
||||
# - default: []
|
||||
# - valid: text, audio, image
|
||||
in:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
# out: list of modalities generated by the model
|
||||
# - default: []
|
||||
# - valid: text, audio, image
|
||||
out:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
# tools: the model supports function calling
|
||||
# - default: false
|
||||
tools: true
|
||||
|
||||
# reranker: the model supports the /v1/rerank endpoint
|
||||
# - default: false
|
||||
reranker: false
|
||||
|
||||
# context: the maximum token context length supported
|
||||
# - default: 0
|
||||
# - must be an integer > 0
|
||||
context: 32000
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -343,84 +386,6 @@ models:
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# =============================================================================
|
||||
# matrix: run concurrent models with a solver-based swap DSL
|
||||
# =============================================================================
|
||||
#
|
||||
# Note:
|
||||
# A config must use either a matrix or legacy groups, not both. A configuration error
|
||||
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||
#
|
||||
# The matrix declares valid combinations of models that can run concurrently.
|
||||
# When a model is requested, the solver finds the cheapest way to make it
|
||||
# available by evicting as few (and least costly) running models as possible.
|
||||
#
|
||||
# Solver behavior:
|
||||
# 1. Request arrives for model X
|
||||
# 2. If X is already running, forward immediately. Done.
|
||||
# 3. Find all sets containing X
|
||||
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||
# every running model NOT in that set
|
||||
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||
# 6. Evict what needs to stop. Start X. Forward request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||
# Only the requested model is started — others are not preloaded.
|
||||
#
|
||||
# A model not appearing in any set can only run alone.
|
||||
#
|
||||
matrix:
|
||||
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||
# - required for sets and evict_costs settings
|
||||
# - each entry is a short name to a real model ID. Do not use an alias
|
||||
# - used to keep set DSL logic short and easier to read
|
||||
# - sets and evict_costs only use identifiers defined in vars
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named sets of concurrent model combinations
|
||||
# Values are DSL strings with operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline another set's expression
|
||||
#
|
||||
# Expansion examples:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → expands llms inline, then applies & v
|
||||
sets:
|
||||
# LLM + TTS: switching between g/q/m won't evict v
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# LLM + TTS + reranker
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# LLM + image generation, no TTS
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# 70B model uses all GPUs, can only run alone
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
@@ -437,6 +402,186 @@ hooks:
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# routing:
|
||||
# Controls how llama-swap decides which models can run at the same time and
|
||||
# which get swapped out. Choose one of two swap engines:
|
||||
#
|
||||
# - group: the default engine. Simpler to configure. You define groups of
|
||||
# models that run together, and loading one group typically unloads
|
||||
# the others.
|
||||
#
|
||||
# - matrix: the newer engine. More involved to configure, but far more
|
||||
# flexible. It uses a small expression language to describe which
|
||||
# model combinations are allowed to run concurrently, enabling
|
||||
# setups that groups cannot express.
|
||||
#
|
||||
# The routing section is optional.
|
||||
routing:
|
||||
router:
|
||||
# use: a string defining which engine to use
|
||||
# - optional, default: "group"
|
||||
# - valid values: group, matrix
|
||||
use: group
|
||||
|
||||
# settings: a dictionary of settings for the specific engines
|
||||
settings:
|
||||
# groups: a dictionary of named groups
|
||||
# - optional, default: empty dictionary
|
||||
# - lets you keep some models loaded while others swap out
|
||||
# - every member must be a model ID defined in the models section
|
||||
# - a model can belong to only one group
|
||||
# - behaviour is set per group with the `swap`, `exclusive` and
|
||||
# `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the model names below are illustrative and are not defined above.
|
||||
groups:
|
||||
# group1 reproduces llama-swap's default behaviour: only one model
|
||||
# runs at a time across the entire instance.
|
||||
"group1":
|
||||
# swap: how members of this group swap among themselves
|
||||
# - optional, default: true
|
||||
# - true: only one member runs at a time
|
||||
# - false: all members can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive: how this group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: running a member unloads every other group
|
||||
# - false: running a member leaves other groups untouched
|
||||
exclusive: true
|
||||
|
||||
# members: the model IDs in this group
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# group2: members all run together, but loading any other group
|
||||
# unloads them.
|
||||
"group2":
|
||||
# swap: false lets all members stay loaded at once
|
||||
swap: false
|
||||
|
||||
# exclusive: false means requesting a member loads it without
|
||||
# unloading any other group
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# forever: a persistent group that other groups can never unload.
|
||||
"forever":
|
||||
# persistent: other groups cannot unload this group's members
|
||||
# - optional, default: false
|
||||
# - has no effect on swapping within the group
|
||||
persistent: true
|
||||
|
||||
# swap/exclusive: false keeps all members loaded and avoids
|
||||
# unloading other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# The matrix lists the model combinations that are allowed to run
|
||||
# concurrently. When a model is requested, the solver makes room for it
|
||||
# by evicting as few running models as possible, preferring to keep the
|
||||
# costliest ones loaded.
|
||||
#
|
||||
# Solver behaviour:
|
||||
# 1. A request arrives for model X.
|
||||
# 2. If X is already running, forward the request. Done.
|
||||
# 3. Collect every set that contains X.
|
||||
# 4. For each set, add up the evict_costs of the running models that
|
||||
# are NOT in that set — that is the set's cost.
|
||||
# 5. Choose the lowest-cost set. Break ties by definition order.
|
||||
# 6. Evict the models outside that set, start X, forward the request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] also permits any subset of itself.
|
||||
# Only the requested model is started; the others are not preloaded.
|
||||
#
|
||||
# A model that appears in no set can only run on its own.
|
||||
#
|
||||
matrix:
|
||||
# vars: short aliases for model IDs (alphanumeric, 1-8 chars)
|
||||
# - required: sets and evict_costs reference these names, not model IDs
|
||||
# - map each short name to a real model ID (not a model alias)
|
||||
# - keeps the set expressions short and readable
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named combinations of models that may run together.
|
||||
# Each value is an expression built from these operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline the expression of another set
|
||||
#
|
||||
# Each expression expands into one or more concrete sets:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → inline the llms set, then AND with v
|
||||
sets:
|
||||
# An LLM plus TTS. Switching between g/q/m keeps v loaded.
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# An LLM plus TTS plus reranker.
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# An LLM plus image generation, no TTS.
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# The 70B model uses every GPU, so it can only run alone.
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# scheduler: how queued requests are ordered and run.
|
||||
# - optional, default on this fork: "serial"
|
||||
# - valid values:
|
||||
# - "serial": strict one-model-at-a-time. Requests run in exact arrival
|
||||
# order; only one request runs at a time; switching to a different model
|
||||
# evicts every other running model first so a single model occupies memory
|
||||
# at a time. This ignores group/matrix co-residency entirely. The "fifo"
|
||||
# settings below (priority) do not apply.
|
||||
# - "fifo": throughput-oriented. Same-model requests are batched to reduce
|
||||
# swaps and a model serves up to its concurrencyLimit in parallel; models
|
||||
# in non-exclusive groups can run concurrently. Requests may be reordered.
|
||||
scheduler:
|
||||
use: serial
|
||||
settings:
|
||||
# fifo settings only apply when use: fifo
|
||||
fifo:
|
||||
# priority: a dictionary of model ID -> priority
|
||||
# - optional, default: empty dictionary
|
||||
# - models default to priority 0
|
||||
# - higher priority requests are serviced first in the queue
|
||||
priority:
|
||||
A: 10
|
||||
B: 5
|
||||
C: 5
|
||||
D: 1
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# Build a CUDA llama-swap image FROM THIS FORK's source (includes the serial
|
||||
# scheduler) and layer it on a pinned llama.cpp CUDA server base. Produces e.g.:
|
||||
# gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||
#
|
||||
# BASE_TAG selects the llama.cpp CUDA runtime + llama-server build, e.g.
|
||||
# "server-cuda-b9821". The llama-swap binary (with the embedded Svelte UI) is
|
||||
# compiled from the repo at build time, so no GitHub release is required.
|
||||
#
|
||||
# Build context is the repo root:
|
||||
# docker build -f docker/fork-cuda.Containerfile \
|
||||
# --build-arg BASE_TAG=server-cuda-b9821 -t llama-swap:v230-cuda-b9821 .
|
||||
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda-b9821
|
||||
|
||||
# ---- Stage 1: build the Svelte UI (embedded into the binary) ----
|
||||
FROM node:22-bookworm-slim AS ui
|
||||
WORKDIR /src/ui-svelte
|
||||
# Install deps first for layer caching. .npmrc carries legacy-peer-deps=true,
|
||||
# which the project relies on (tailwind/vite peer ranges), so copy it before
|
||||
# npm ci or the strict resolver fails with ERESOLVE.
|
||||
COPY ui-svelte/package.json ui-svelte/package-lock.json ui-svelte/.npmrc ./
|
||||
RUN npm ci
|
||||
COPY ui-svelte/ ./
|
||||
# `npm run build` is `vite build --emptyOutDir`; vite.config.ts writes to
|
||||
# ../internal/server/ui_dist, which //go:embed picks up in the next stage.
|
||||
RUN mkdir -p /src/internal/server && npm run build
|
||||
|
||||
# ---- Stage 2: build the llama-swap binary with the embedded UI ----
|
||||
FROM golang:1.26-bookworm AS build
|
||||
WORKDIR /src
|
||||
# Cache modules independently of source churn.
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Overlay the freshly built UI so //go:embed ui_dist ships the real assets
|
||||
# instead of the committed placeholder.
|
||||
COPY --from=ui /src/internal/server/ui_dist/ ./internal/server/ui_dist/
|
||||
ARG LS_VERSION=v230
|
||||
ARG GIT_HASH=unknown
|
||||
ARG BUILD_DATE=unknown
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags="-X main.version=${LS_VERSION} -X main.commit=${GIT_HASH} -X main.date=${BUILD_DATE}" \
|
||||
-o /out/llama-swap .
|
||||
|
||||
# ---- Stage 3: runtime image on the pinned llama.cpp CUDA base ----
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# Run as root by default to match the upstream `vNNN-cuda-bNNNN` (non-suffixed)
|
||||
# image that ragnaros pulls today: it needs root to reach the mounted docker
|
||||
# socket for container-backed models (sd-server). Override UID/GID at build time
|
||||
# for a non-root variant.
|
||||
ARG UID=0
|
||||
ARG GID=0
|
||||
ARG USER_HOME=/root
|
||||
ENV HOME=$USER_HOME
|
||||
|
||||
RUN set -eux; \
|
||||
if [ "$UID" -ne 0 ]; then \
|
||||
if [ "$GID" -ne 0 ]; then groupadd --system --gid "$GID" app; fi; \
|
||||
useradd --system --uid "$UID" --gid "$GID" --home "$USER_HOME" app; \
|
||||
fi; \
|
||||
mkdir --parents "$HOME" /app; \
|
||||
chown --recursive "$UID:$GID" "$HOME" /app
|
||||
|
||||
COPY --from=build --chown=$UID:$GID /out/llama-swap /app/llama-swap
|
||||
COPY --chown=$UID:$GID docker/config.example.yaml /app/config.yaml
|
||||
|
||||
USER $UID:$GID
|
||||
WORKDIR /app
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -2,10 +2,6 @@ ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
# TARGETARCH is auto-set by `docker buildx build --platform …` (amd64/arm64);
|
||||
# falls back to amd64 when an older `docker build` runs without buildx.
|
||||
ARG TARGETARCH=amd64
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
@@ -37,9 +33,15 @@ WORKDIR /app
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz"
|
||||
set -eux; \
|
||||
case "$(uname -m)" in \
|
||||
x86_64) ARCH=amd64 ;; \
|
||||
aarch64) ARCH=arm64 ;; \
|
||||
*) echo "unsupported arch: $(uname -m)" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
curl --fail -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
|
||||
@@ -0,0 +1,264 @@
|
||||
# New Router Migration TODO
|
||||
|
||||
This document tracks the work needed for [cmd/newrouter/main.go](../cmd/newrouter/main.go) and [internal/router/](../internal/router/) to reach feature parity with the legacy entrypoint at [llama-swap.go](../llama-swap.go) plus [proxy/proxymanager.go](../proxy/proxymanager.go).
|
||||
|
||||
The work is split into phases so each can land and be tested independently. Earlier phases unblock later ones.
|
||||
|
||||
## Current state (newrouter)
|
||||
|
||||
`cmd/newrouter` already supports:
|
||||
|
||||
- Loading config via `-config`
|
||||
- Selecting Matrix vs Group router based on config
|
||||
- Peer routing fallback
|
||||
- Plain HTTP listen (`-listen`)
|
||||
- Graceful shutdown on `SIGINT` / `SIGTERM`
|
||||
- Model extraction from JSON body, query string, and form bodies (see [router.go:88](../internal/router/router.go#L88))
|
||||
- `Server.ServeHTTP` dispatches a single request to peer or local router based on the requested model
|
||||
|
||||
Everything below is missing or only partially implemented.
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — Package relocation -- Completed.
|
||||
|
||||
Goal: move shared infrastructure packages out from under `proxy/` so the new router does not depend on the legacy proxy tree. This is a prerequisite for retiring `proxy/` in Phase 8.
|
||||
|
||||
---
|
||||
|
||||
## Phase 2 — Server lifecycle parity -- Completed.
|
||||
|
||||
Goal: make `cmd/newrouter` a drop-in replacement for the legacy binary's process model, _without_ yet adding any extra HTTP endpoints.
|
||||
|
||||
---
|
||||
|
||||
## Phase 3 — `internal/chain` package -- Completed.
|
||||
|
||||
API: `chain.New(mws...).Then(final)` for ServeMux registration; `Append` returns an extended Chain without mutating the receiver, so a base stack (auth/CORS) can be reused across many routes with per-route additions.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4 — `internal/server` package scaffolding (ProxyManager replacement) -- Completed.
|
||||
|
||||
Goal: build the [internal/server](../internal/server/) package so it can stand in for [proxy.ProxyManager](../proxy/proxymanager.go#L67) — the mux, lifecycle, model dispatch, custom endpoints, request filters, auth/CORS, and upstream passthrough. After this phase, `cmd/newrouter/main.go` constructs a `server.Server` instead of a bare `router.Server`.
|
||||
|
||||
The legacy `ProxyManager` collapses three concerns into one struct: the HTTP mux, the model→process router, and the cross-cutting services (loggers, metrics, perf, inflight counter, version). The new layout keeps the `router.Router` implementations focused on model dispatch and lets `internal/server.Server` own the mux and all cross-cutting middleware. `server.Server` builds the `local` and `peer` routers directly and dispatches between them itself, so it fully **supersedes `internal/router.Server`** — see the cleanup item below.
|
||||
|
||||
The phase is split into sub-phases that can land and be tested independently:
|
||||
|
||||
| Sub-phase | Scope |
|
||||
| --------- | -------------------------------------------------------------------------- |
|
||||
| 4a | package scaffolding — struct, `New`, `ServeHTTP`, `Shutdown`, model routes |
|
||||
| 4b | custom (non-model-dispatched) HTTP endpoints |
|
||||
| 4c | request-body filter middleware |
|
||||
| 4d | auth & CORS middleware |
|
||||
| 4e | upstream passthrough |
|
||||
|
||||
The package is split by concern across stub files already in place:
|
||||
|
||||
| File | Responsibility | Filled in by |
|
||||
| ------------ | ----------------------------------------------- | ---------------------- |
|
||||
| `server.go` | `Server` struct, `New`, `ServeHTTP`, `Shutdown` | 4a |
|
||||
| `log.go` | `muxlog` combined logger; `/logs` handlers | 4a |
|
||||
| `auth.go` | `CreateAuthMiddleware` | 4d |
|
||||
| `filters.go` | request-body filter middleware | 4c |
|
||||
| `api.go` | llama-swap-specific API handlers | 4b / Phase 5 / Phase 6 |
|
||||
| `ui.go` | embedded UI serving | Phase 7 |
|
||||
|
||||
### Phase 4a — package scaffolding -- Completed.
|
||||
|
||||
`server.Server` owns the mux, the `local`/`peer` routers, `muxlog`, and a
|
||||
shutdown context. `New` builds the routers, registers all model-dispatched
|
||||
routes on a stdlib `http.ServeMux`, and wraps the mux with the global CORS
|
||||
middleware. `localPeerHandler` resolves the model once via `router.FetchModel`
|
||||
and dispatches to `local` or `peer`. `Shutdown` stops both routers in parallel
|
||||
and is idempotent. `cmd/newrouter/main.go` now constructs `server.New(...)`;
|
||||
`internal/router/server.go` and `server_test.go` were removed as dead code.
|
||||
|
||||
### Phase 4b — Custom HTTP endpoints -- Completed.
|
||||
|
||||
`GET /v1/models` (local + peer models, aliases, metadata), `GET /health`,
|
||||
`GET /wol-health`, and `GET /` → `/ui` are registered. `GET /favicon.ico` is
|
||||
deferred to Phase 7 since it requires the embedded UI filesystem.
|
||||
|
||||
### Phase 4c — Request-body filters -- Completed.
|
||||
|
||||
`CreateFilterMiddleware` (in `filters.go`) applies `UseModelName`,
|
||||
`StripParams`, `SetParams`, and `SetParamsByID` to JSON requests, then
|
||||
re-attaches the body with `Content-Length` / `Transfer-Encoding` cleanup.
|
||||
|
||||
### Phase 4d — Auth & CORS -- Completed.
|
||||
|
||||
`CreateAuthMiddleware` validates API keys (Bearer / Basic / `x-api-key`) and
|
||||
strips the headers before upstream. `CreateCORSMiddleware` answers OPTIONS
|
||||
preflight; `/v1/models` echoes the `Origin`.
|
||||
|
||||
### Phase 4e — Upstream passthrough -- Completed.
|
||||
|
||||
`GET /upstream` → `/ui/models`, and `/upstream/<model>/<path>` proxies to the
|
||||
resolved model with multi-segment name resolution, canonical-form redirect
|
||||
(301/308), and prefix stripping.
|
||||
|
||||
---
|
||||
|
||||
## Phase 5 — Operations endpoints -- Completed.
|
||||
|
||||
A new `router.LocalRouter` interface embeds `Router` and adds `RunningModels()`
|
||||
and `Unload(timeout, models...)`, both implemented once on `baseRouter` so
|
||||
`Group` and `Matrix` share them — the legacy matrix/group divergence at
|
||||
[proxymanager.go:1167](../proxy/proxymanager.go#L1167) collapses since
|
||||
`baseRouter` already unifies process storage. `Peer` does not implement it;
|
||||
`Server.local` is typed `LocalRouter`, `Server.peer` stays `Router`.
|
||||
|
||||
`GET /unload` stops every local process; `GET /running` lists non-stopped
|
||||
processes joined against config for `cmd`/`proxy`/`ttl`/`name`/`description`.
|
||||
`startPreload` fires a background `GET /` at each `Hooks.OnStartup.Preload`
|
||||
model and emits `shared.ModelPreloadedEvent`.
|
||||
|
||||
---
|
||||
|
||||
## Phase 6 — Metrics, perf, and SSE -- Completed.
|
||||
|
||||
`perf.Monitor` is created and started in `cmd/newrouter/main.go` (it outlives
|
||||
config reloads via `UpdateConfig`) and passed into `server.New`. `GET /metrics`
|
||||
serves `perf.Monitor.MetricsHandler()` output, 503 when disabled.
|
||||
|
||||
`internal/process` emits `shared.ProcessStateChangeEvent` from `setState`.
|
||||
`server.inflightCounter` (atomic) + `CreateInflightMiddleware` track
|
||||
model-dispatched requests and emit `InFlightRequestsEvent`. `metricsMonitor`
|
||||
(in `metrics.go`) parses token usage from upstream responses via
|
||||
`CreateMetricsMiddleware`.
|
||||
|
||||
The `/api` group (API-key protected) is registered: `POST /api/models/unload`,
|
||||
`POST /api/models/unload/{model...}`, `GET /api/events` (SSE: `modelStatus` /
|
||||
`logData` / `metrics` / `inflight`), `GET /api/metrics`, `GET /api/performance`
|
||||
(`?after=` RFC3339 filter), `GET /api/version`. `GET /api/captures/{id}`
|
||||
returns 501 until 6f.
|
||||
|
||||
### Phase 6f — Request/response captures -- Completed.
|
||||
|
||||
`proxy/cache` moved to `internal/cache`. `metricsMonitor` stores zstd+CBOR
|
||||
`ReqRespCapture` records in a sized `cache.Cache` (`captureBuffer` MB, 0
|
||||
disables). `CreateMetricsMiddleware` buffers request body/headers before
|
||||
dispatch; `record` builds the capture per a `captureFieldsByPath` table
|
||||
(`captures.go`) that trims large audio/image payloads, defaulting JSON routes
|
||||
to `captureAll`. `GET /api/captures/{id}` decompresses and returns the capture;
|
||||
`getMetrics` resolves `HasCapture` against the cache.
|
||||
|
||||
---
|
||||
|
||||
## Phase 7 — UI serving -- Completed.
|
||||
|
||||
`internal/server/ui.go` embeds `ui_dist` and serves it. `GET /ui/` is
|
||||
brotli/gzip-aware via `serveCompressedFile`; unknown paths without a file
|
||||
extension fall back to `index.html` for SPA routing. `GET /favicon.ico` serves
|
||||
from the same embedded FS. The Makefile `ui` target copies the vite build into
|
||||
`internal/server/ui_dist`; a committed `placeholder.txt` keeps the embed valid
|
||||
before a build runs.
|
||||
|
||||
---
|
||||
|
||||
## Phase 8a - Review Part I
|
||||
|
||||
- [x] All functionality from the proxy package has been migrated in the above phases — with the remaining gaps listed in Phase 8b
|
||||
- [x] Test coverage at or exceeds the level from the proxy package — `internal/server` now at 76.6% vs 73.9% (`proxy`)
|
||||
|
||||
### Findings
|
||||
|
||||
**Gap 1 — Request logging middleware missing -- Resolved.**
|
||||
|
||||
`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records one
|
||||
access-log line per request to `s.proxylog` in the legacy format
|
||||
`clientIP "METHOD PATH PROTO" status bodySize "UA" duration`, skipping
|
||||
`/wol-health`, `/api/performance`, and `/metrics`. A `statusRecorder` captures
|
||||
the status/body size (forwarding `Flush` for SSE) and `clientIP` honours
|
||||
`X-Forwarded-For` / `X-Real-IP`. It is wired as the outermost middleware in
|
||||
`routes()`, wrapping the CORS layer.
|
||||
|
||||
**Gap 2 — Per-model log streaming not supported -- Resolved **
|
||||
|
||||
`Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) only handles `""`, `"proxy"`, and `"upstream"`. The legacy `ProxyManager.getLogger` ([proxymanager_loghandlers.go:92](../proxy/proxymanager_loghandlers.go#L92)) additionally resolves a model ID against the active process groups / matrix and returns that process's logger. Callers of `GET /logs/stream/<modelID>` will get a 400 instead of the model's live log stream.
|
||||
|
||||
**Gap 3 — `UseModelName` not applied to multipart form endpoints -- Resolved.**
|
||||
|
||||
`CreateFormFilterMiddleware` ([filters.go](../internal/server/filters.go)) parses
|
||||
`multipart/form-data` requests, rewrites the `model` field with `UseModelName`,
|
||||
reconstructs the body via `rewriteMultipartModel`, and re-attaches it with
|
||||
`Content-Type` / `Content-Length` cleanup. It runs in `modelChain` after the
|
||||
JSON `filterMW`; each is a no-op for the other's Content-Type. Audio
|
||||
transcription (`/v1/audio/transcriptions`) and image edit (`/v1/images/edits`)
|
||||
now honour `use_model_name`.
|
||||
|
||||
**Coverage gaps (0 % functions) -- Resolved.**
|
||||
|
||||
The functions previously at 0 % (`handleListModels`, `handleMetrics`,
|
||||
`handleRootRedirect`, `handleUpstreamRedirect`, `handleUpstream`,
|
||||
`findModelInPath`, `handleAPICapture`, `handleAPIUnloadAll`,
|
||||
`handleAPIUnloadModel`, `CreateAuthMiddleware`, `extractAPIKey`,
|
||||
`handleLogStream`, `applyFilters`, `decompressBody`, `filterAcceptEncoding`,
|
||||
`handleUI`, `handleFavicon`) now have tests across `auth_test.go`, `api_test.go`,
|
||||
`filters_test.go`, `log_test.go`, and `extras_test.go`.
|
||||
|
||||
---
|
||||
|
||||
### Phase 8b - Fill gaps discovered in Phase 8a
|
||||
|
||||
- [x] **Add request-log middleware** — `CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records `clientIP "METHOD PATH PROTO" status bodySize "UA" duration` to `s.proxylog`, skips `/wol-health` / `/api/performance` / `/metrics`, and is wired as the outermost middleware in `routes()`.
|
||||
- [x] **Extend `getLogger` with model-ID resolution** — add a `default:` branch to `Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) that resolves the ID via `s.local` (using a new `LocalRouter.GetProcess(name)` method or equivalent) and returns that process's `Logger()`. Match the fallback behaviour: return a 400 with `"invalid logger. Use 'proxy', 'upstream' or a model's ID"` when not found.
|
||||
- [x] **`UseModelName` rewrite for multipart endpoints** — `CreateFormFilterMiddleware` parses `multipart/form-data`, rewrites the `model` field according to `UseModelName`, reconstructs the body, and updates `Content-Type` / `Content-Length`. It is wired into `modelChain` after the JSON filter.
|
||||
- [x] **Raise test coverage to ≥ 74 %** — `internal/server` now at 76.1%; tests added for every 0 % function across `auth_test.go`, `api_test.go`, `filters_test.go`, `log_test.go`, and `extras_test.go`.
|
||||
|
||||
---
|
||||
|
||||
## Phase 8c - Review Part II (entrypoint comparison)
|
||||
|
||||
A second pass comparing [cmd/newrouter/main.go](../cmd/newrouter/main.go) against
|
||||
the legacy [llama-swap.go](../llama-swap.go) + [proxy.New](../proxy/proxymanager.go#L104)
|
||||
surfaced four more gaps, all in logger setup.
|
||||
|
||||
**Gap 4 — `LogToStdout` config ignored -- Resolved.**
|
||||
|
||||
`cmd/newrouter/main.go` previously hardcoded `proxyLog` / `upstreamLog` to
|
||||
`os.Stdout`, and the old `muxlog()` helper built a Monitor that nothing wrote
|
||||
into — so `logToStdout` had no effect and `/logs` (combined history) was always
|
||||
empty. `server.NewLoggers` ([log.go](../internal/server/log.go)) now replicates
|
||||
the legacy switch: `proxy` / `upstream` monitors feed `muxLog` (or `io.Discard`)
|
||||
per `none` / `both` / `upstream` / `proxy`, so `muxLog` accumulates the combined
|
||||
history. `server.New` takes `muxlog` as a parameter. The loggers outlive config
|
||||
reloads, so a `LogToStdout` change requires a restart to take effect.
|
||||
|
||||
**Gap 5 — `LogTimeFormat` config ignored -- Resolved.**
|
||||
|
||||
`cmd/newrouter/main.go` now maps `cfg.LogTimeFormat` to a Go time layout via the
|
||||
`logTimeFormats` table and applies it (alongside log level) to the proxy and
|
||||
upstream monitors in `applyLogSettings`, re-applied on config reload.
|
||||
|
||||
**Gap 6 — `LogRequests` deprecation warning missing.**
|
||||
|
||||
The legacy [proxymanager.go:127](../proxy/proxymanager.go#L127) warns when the
|
||||
deprecated `logRequests` config key is set. `cmd/newrouter` does not. Low
|
||||
priority — left open.
|
||||
|
||||
**Gap 7 — PID debug log missing -- Resolved.**
|
||||
|
||||
`cmd/newrouter/main.go` now logs `PID: %d` at debug level after `applyLogSettings`,
|
||||
matching [llama-swap.go:71](../llama-swap.go#L71).
|
||||
|
||||
---
|
||||
|
||||
## Phase X (tbd) — Cutover
|
||||
|
||||
- [ ] Swap `llama-swap.go` to delegate to `cmd/newrouter` (or rename newrouter to be the primary entrypoint)
|
||||
- [ ] Update `Makefile` build targets
|
||||
- [ ] Update docs / README references to the legacy binary
|
||||
- [ ] Remove `proxy/proxymanager*.go` and `gin-gonic` dependency once nothing imports them
|
||||
- [ ] Run `make test-all` and confirm concurrency suite still passes against the new entrypoint
|
||||
|
||||
---
|
||||
|
||||
## Cross-cutting concerns to keep in mind
|
||||
|
||||
- **Single body read**: legacy and newrouter both buffer the request body once. When adding filters (Phase 4c), make sure the buffered bytes flow through `Content-Length` / `transfer-encoding` cleanup as in [proxymanager.go:872](../proxy/proxymanager.go#L872).
|
||||
- **Streaming flag in context**: legacy stashes `streaming` and `model` under `proxyCtxKey`. The new router uses `ModelKey` / `ModelIDKey` — pick one set of keys and use them consistently for metrics + log handlers.
|
||||
- **Matrix vs Group divergence**: any handler that calls `swapProcessGroup` or `findGroupByModelName` in the legacy needs a matrix branch too. The new router's `Router` interface already abstracts this — preserve that abstraction rather than reintroducing the branch in every handler.
|
||||
- **Shutdown ordering**: `httpServer.Shutdown` must drain inflight requests _before_ `Server.Shutdown` tears down processes, otherwise inflight requests 502. Current newrouter ordering at [main.go:87](../cmd/newrouter/main.go#L87) is correct — keep it.
|
||||
@@ -4,23 +4,38 @@ go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.1
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/jsonschema-go v0.4.3
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/shirou/gopsutil/v4 v4.26.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.41.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
@@ -31,13 +46,20 @@ require (
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||
@@ -45,11 +67,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,9 +1,31 @@
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
@@ -13,6 +35,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
@@ -37,6 +61,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
@@ -47,21 +73,35 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY=
|
||||
github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -97,6 +137,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -104,10 +146,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
// Package chain composes http.Handler middleware into a single handler.
|
||||
//
|
||||
// A Middleware wraps a downstream http.Handler and may run logic before or
|
||||
// after delegating to it, or short-circuit by not calling next at all
|
||||
// (e.g. auth failure, CORS preflight).
|
||||
package chain
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Middleware wraps an http.Handler with cross-cutting behavior. It receives
|
||||
// the next handler in the chain and returns a handler that may call next,
|
||||
// modify the request/response around it, or short-circuit.
|
||||
type Middleware func(next http.Handler) http.Handler
|
||||
|
||||
// Chain is a reusable middleware stack. Build it once with New (and optionally
|
||||
// extend per-route with Append), then call Then to wrap each terminal handler
|
||||
// when registering routes against an http.ServeMux:
|
||||
//
|
||||
// api := chain.New(authMW, corsMW)
|
||||
// mux.Handle("/v1/chat/completions", api.Then(dispatch))
|
||||
// mux.Handle("/v1/embeddings", api.Append(filters).Then(dispatch))
|
||||
//
|
||||
// Middlewares execute left-to-right: mws[0] runs first and may call into
|
||||
// mws[1], and so on, with the terminal handler invoked last. A middleware
|
||||
// that does not call next short-circuits the remainder of the chain.
|
||||
// A zero Chain is valid and applies no middleware.
|
||||
type Chain struct {
|
||||
mws []Middleware
|
||||
}
|
||||
|
||||
// New returns a Chain that applies mws left-to-right around any terminal
|
||||
// handler passed to Then.
|
||||
func New(mws ...Middleware) Chain {
|
||||
cp := make([]Middleware, len(mws))
|
||||
copy(cp, mws)
|
||||
return Chain{mws: cp}
|
||||
}
|
||||
|
||||
// Append returns a new Chain with mws added after the existing middleware.
|
||||
// The receiver is not modified, so a base Chain can be safely reused across
|
||||
// multiple routes that each need different per-route additions.
|
||||
func (c Chain) Append(mws ...Middleware) Chain {
|
||||
out := make([]Middleware, 0, len(c.mws)+len(mws))
|
||||
out = append(out, c.mws...)
|
||||
out = append(out, mws...)
|
||||
return Chain{mws: out}
|
||||
}
|
||||
|
||||
// Then wraps final with the chain's middleware and returns the resulting
|
||||
// handler, suitable for passing to http.ServeMux.Handle. With an empty chain,
|
||||
// Then returns final unchanged.
|
||||
func (c Chain) Then(final http.Handler) http.Handler {
|
||||
h := final
|
||||
for i := len(c.mws) - 1; i >= 0; i-- {
|
||||
h = c.mws[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// ThenFunc is shorthand for Then(http.HandlerFunc(f)).
|
||||
func (c Chain) ThenFunc(f http.HandlerFunc) http.Handler {
|
||||
return c.Then(f)
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// recordingMiddleware appends tag before calling next and "-after-"+tag after.
|
||||
func recordingMiddleware(tag string, log *[]string) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
*log = append(*log, tag)
|
||||
next.ServeHTTP(w, r)
|
||||
*log = append(*log, "after-"+tag)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_HandlersExecuteInDeclaredOrder(t *testing.T) {
|
||||
var log []string
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "final")
|
||||
})
|
||||
|
||||
h := New(
|
||||
recordingMiddleware("a", &log),
|
||||
recordingMiddleware("b", &log),
|
||||
recordingMiddleware("c", &log),
|
||||
).Then(final)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
want := []string{"a", "b", "c", "final", "after-c", "after-b", "after-a"}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("execution order mismatch:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_ShortCircuitsWhenMiddlewareDoesNotCallNext(t *testing.T) {
|
||||
var log []string
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "final")
|
||||
})
|
||||
|
||||
gate := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "gate")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
h := New(
|
||||
recordingMiddleware("outer", &log),
|
||||
gate,
|
||||
recordingMiddleware("inner", &log),
|
||||
).Then(final)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
want := []string{"outer", "gate", "after-outer"}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("short-circuit order mismatch:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_EarlyWritesAreVisibleToLaterMiddleware(t *testing.T) {
|
||||
header := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Set-By", "outer")
|
||||
_, _ = io.WriteString(w, "outer:")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
inner := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// The outer middleware already set the header; we should see it.
|
||||
if got := w.Header().Get("X-Set-By"); got != "outer" {
|
||||
_, _ = io.WriteString(w, "missing-header;")
|
||||
}
|
||||
_, _ = io.WriteString(w, "inner:")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "final")
|
||||
})
|
||||
|
||||
h := New(header, inner).Then(final)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
body, _ := io.ReadAll(rec.Body)
|
||||
if got := string(body); !strings.Contains(got, "outer:inner:final") {
|
||||
t.Fatalf("body: got %q, want it to contain %q", got, "outer:inner:final")
|
||||
}
|
||||
if got := rec.Header().Get("X-Set-By"); got != "outer" {
|
||||
t.Fatalf("header X-Set-By: got %q, want %q", got, "outer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_ReusableAcrossRoutesViaThen(t *testing.T) {
|
||||
var log []string
|
||||
base := New(
|
||||
recordingMiddleware("auth", &log),
|
||||
recordingMiddleware("cors", &log),
|
||||
)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/a", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "handler-a")
|
||||
}))
|
||||
mux.Handle("/b", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "handler-b")
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
for _, path := range []string{"/a", "/b"} {
|
||||
resp, err := http.Get(srv.URL + path)
|
||||
if err != nil {
|
||||
t.Fatalf("GET %s: %v", path, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
want := []string{
|
||||
"auth", "cors", "handler-a", "after-cors", "after-auth",
|
||||
"auth", "cors", "handler-b", "after-cors", "after-auth",
|
||||
}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("reusable chain order mismatch:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AppendDoesNotMutateReceiver(t *testing.T) {
|
||||
var log []string
|
||||
base := New(recordingMiddleware("base", &log))
|
||||
extended := base.Append(recordingMiddleware("extra", &log))
|
||||
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "final")
|
||||
})
|
||||
|
||||
// Run extended first to surface any aliasing of the underlying slice.
|
||||
rec := httptest.NewRecorder()
|
||||
extended.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
base.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
want := []string{
|
||||
"base", "extra", "final", "after-extra", "after-base",
|
||||
"base", "final", "after-base",
|
||||
}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("Append must not mutate the receiver:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_ZeroValueAndEmptyThenAreIdentity(t *testing.T) {
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
})
|
||||
|
||||
for name, c := range map[string]Chain{
|
||||
"zero": {},
|
||||
"empty": New(),
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := c.Then(final)
|
||||
if _, ok := h.(http.HandlerFunc); !ok {
|
||||
t.Fatalf("expected http.HandlerFunc identity, got %T", h)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if rec.Code != http.StatusTeapot {
|
||||
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusTeapot)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func equal(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
)
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type MacroList []MacroEntry
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("macros must be a mapping")
|
||||
}
|
||||
|
||||
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||
}
|
||||
|
||||
var val any
|
||||
if err := valueNode.Decode(&val); err != nil {
|
||||
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||
}
|
||||
|
||||
*ml = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a macro value by name
|
||||
func (ml MacroList) Get(name string) (any, bool) {
|
||||
for _, entry := range ml {
|
||||
if entry.Name == name {
|
||||
return entry.Value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||
func (ml MacroList) ToMap() map[string]any {
|
||||
result := make(map[string]any, len(ml))
|
||||
for _, entry := range ml {
|
||||
result[entry.Name] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
Performance PerformanceConfig `yaml:"performance"`
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
// routing is the canonical source for swap/scheduling configuration.
|
||||
// New code must read Routing, never the backwards-compat fields below.
|
||||
Routing RoutingConfig `yaml:"routing"`
|
||||
|
||||
// Groups and Matrix are permanent backwards-compat input fields for the
|
||||
// legacy top-level `groups:`/`matrix:` keys. They are normalized into
|
||||
// Routing by LoadConfigFromReader. New code must not read them directly.
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
|
||||
// send loading state in reasoning
|
||||
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
|
||||
// upstream controls behaviour of the /upstream passthrough endpoint
|
||||
Upstream UpstreamConfig `yaml:"upstream"`
|
||||
}
|
||||
|
||||
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||
type RoutingConfig struct {
|
||||
Scheduler SchedulerConfig `yaml:"scheduler"`
|
||||
Router RouterConfig `yaml:"router"`
|
||||
}
|
||||
|
||||
type SchedulerConfig struct {
|
||||
Use string `yaml:"use"` // default "fifo"
|
||||
Settings SchedulerSettings `yaml:"settings"`
|
||||
}
|
||||
|
||||
type SchedulerSettings struct {
|
||||
Fifo FifoConfig `yaml:"fifo"`
|
||||
}
|
||||
|
||||
type FifoConfig struct {
|
||||
Priority map[string]int `yaml:"priority"` // model ID -> priority, default 0
|
||||
}
|
||||
|
||||
type RouterConfig struct {
|
||||
Use string `yaml:"use"` // "group" (default) | "matrix"
|
||||
Settings RouterSettings `yaml:"settings"`
|
||||
}
|
||||
|
||||
type RouterSettings struct {
|
||||
Groups map[string]GroupConfig `yaml:"groups"`
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
@@ -173,6 +173,25 @@ groups:
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expectedGroups := map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
@@ -189,42 +208,46 @@ groups:
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -242,22 +265,19 @@ groups:
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
Groups: expectedGroups,
|
||||
Upstream: UpstreamConfig{
|
||||
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||
},
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
Settings: RouterSettings{
|
||||
Groups: expectedGroups,
|
||||
},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "serial",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestConfig_ExampleMatchesSchema validates that config.example.yaml conforms to
|
||||
// config-schema.json. Both files live at the repository root.
|
||||
func TestConfig_ExampleMatchesSchema(t *testing.T) {
|
||||
const (
|
||||
schemaPath = "../../config-schema.json"
|
||||
examplePath = "../../config.example.yaml"
|
||||
)
|
||||
|
||||
schemaBytes, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading %s: %v", schemaPath, err)
|
||||
}
|
||||
|
||||
var schema jsonschema.Schema
|
||||
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
|
||||
t.Fatalf("unmarshalling schema: %v", err)
|
||||
}
|
||||
|
||||
resolved, err := schema.Resolve(&jsonschema.ResolveOptions{
|
||||
BaseURI: "https://github.com/mostlygeek/llama-swap/",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolving schema: %v", err)
|
||||
}
|
||||
|
||||
exampleBytes, err := os.ReadFile(examplePath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading %s: %v", examplePath, err)
|
||||
}
|
||||
|
||||
// Convert YAML to a JSON-like value so numbers and keys match what the
|
||||
// validator expects.
|
||||
var yamlValue any
|
||||
if err := yaml.Unmarshal(exampleBytes, &yamlValue); err != nil {
|
||||
t.Fatalf("unmarshalling example yaml: %v", err)
|
||||
}
|
||||
jsonBytes, err := json.Marshal(yamlValue)
|
||||
if err != nil {
|
||||
t.Fatalf("converting example to json: %v", err)
|
||||
}
|
||||
var instance any
|
||||
if err := json.Unmarshal(jsonBytes, &instance); err != nil {
|
||||
t.Fatalf("unmarshalling example json: %v", err)
|
||||
}
|
||||
|
||||
if err := resolved.Validate(instance); err != nil {
|
||||
t.Fatalf("config.example.yaml does not match config-schema.json:\n%v", err)
|
||||
}
|
||||
}
|
||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "space in second key reports correct index",
|
||||
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
@@ -1544,3 +1549,174 @@ peers:
|
||||
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
// twoModels is a minimal models block reused by the routing tests below.
|
||||
const twoModels = `
|
||||
models:
|
||||
gemma:
|
||||
cmd: echo gemma
|
||||
proxy: http://localhost:8080
|
||||
qwen:
|
||||
cmd: echo qwen
|
||||
proxy: http://localhost:8081
|
||||
`
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelGroups(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
// default group injected for orphaned models (none here) still leaves g1
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: matrix
|
||||
settings:
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseGroup(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
settings:
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "migrate")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseMatrixWithoutSettings(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: matrix
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "routing.router.settings.matrix is not set")
|
||||
}
|
||||
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active.
|
||||
func TestConfig_Routing_RouterSettingsBothGroupsAndMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
settings:
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
matrix:
|
||||
sets:
|
||||
s: "gemma"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
// use: group means groups are active and matrix is ignored
|
||||
assert.Equal(t, "group", config.Routing.Router.Use)
|
||||
assert.Nil(t, config.Matrix)
|
||||
assert.Contains(t, config.Groups, "g1")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_UnknownRouter(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: bogus
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown router")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_FifoPriorityUnknownModel(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
scheduler:
|
||||
settings:
|
||||
fifo:
|
||||
priority:
|
||||
nope: 5
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown model")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_FifoPriorityKnownModel(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
scheduler:
|
||||
settings:
|
||||
fifo:
|
||||
priority:
|
||||
gemma: 5
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, cfg.Routing.Scheduler.Settings.Fifo.Priority["gemma"])
|
||||
}
|
||||
@@ -165,6 +165,25 @@ groups:
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expectedGroups := map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
@@ -176,44 +195,48 @@ groups:
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -231,22 +254,19 @@ groups:
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
Groups: expectedGroups,
|
||||
Upstream: UpstreamConfig{
|
||||
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||
},
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
Settings: RouterSettings{
|
||||
Groups: expectedGroups,
|
||||
},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "serial",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,441 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
// Apply default for upstream.ignorePaths when not specified. The default
|
||||
// matches common static-asset suffixes so they do not trigger a swap.
|
||||
if len(config.Upstream.IgnorePaths) == 0 {
|
||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
// This fork defaults to the "serial" scheduler: one model loaded at a time,
|
||||
// requests served in strict arrival order. Set use: fifo for the upstream
|
||||
// throughput-oriented behavior that batches same-model requests.
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "serial"
|
||||
}
|
||||
switch config.Routing.Scheduler.Use {
|
||||
case "fifo", "serial":
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -15,6 +15,9 @@ type MatrixConfig struct {
|
||||
Var map[string]string `yaml:"vars"`
|
||||
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||
Sets OrderedSets `yaml:"sets"`
|
||||
|
||||
// populated by ValidateMatrix; not settable from yaml
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
}
|
||||
|
||||
// SetEntry is a single named set with its DSL expression.
|
||||
@@ -289,7 +289,9 @@ matrix:
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Matrix)
|
||||
assert.Len(t, cfg.ExpandedSets, 2)
|
||||
assert.Len(t, cfg.Matrix.ExpandedSets, 2)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
// Groups should be empty when matrix is used
|
||||
assert.Empty(t, cfg.Groups)
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// identityMapPaths is the set of dotted paths whose direct children are
|
||||
// identity-keyed maps. A child key present in two sources is a hard error;
|
||||
// such keys name discrete entities (a model, a group, a peer, etc.) and a
|
||||
// duplicate means the user has split one entity across files by mistake.
|
||||
var identityMapPaths = map[string]bool{
|
||||
"models": true,
|
||||
"groups": true,
|
||||
"profiles": true,
|
||||
"peers": true,
|
||||
"matrix": true,
|
||||
"routing.router.settings.groups": true,
|
||||
"routing.router.settings.matrix": true,
|
||||
}
|
||||
|
||||
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||
// and -config-dir (optional). At least one must be provided. The -config file
|
||||
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||
// merged in sorted filename order. The merged document is passed through the
|
||||
// existing LoadConfigFromReader pipeline unchanged.
|
||||
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||
if configPath == "" && configDir == "" {
|
||||
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||
}
|
||||
|
||||
var sourcePaths []string
|
||||
|
||||
if configPath != "" {
|
||||
sourcePaths = append(sourcePaths, configPath)
|
||||
}
|
||||
|
||||
if configDir != "" {
|
||||
dirFiles, err := listYAMLFiles(configDir)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
absConfig, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||
}
|
||||
for _, f := range dirFiles {
|
||||
absF, err := filepath.Abs(f)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||
}
|
||||
if absConfig == absF {
|
||||
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sourcePaths = append(sourcePaths, dirFiles...)
|
||||
}
|
||||
|
||||
if len(sourcePaths) == 0 {
|
||||
return Config{}, fmt.Errorf("no configuration sources found")
|
||||
}
|
||||
|
||||
var merged *yaml.Node
|
||||
for _, p := range sourcePaths {
|
||||
node, err := parseSource(p)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if node == nil {
|
||||
continue // empty file
|
||||
}
|
||||
if merged == nil {
|
||||
merged = node
|
||||
continue
|
||||
}
|
||||
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if merged == nil {
|
||||
// All sources were empty; run the pipeline on empty input so defaults
|
||||
// and validation still apply (e.g. startPort, performance defaults).
|
||||
return LoadConfigFromReader(strings.NewReader(""))
|
||||
}
|
||||
|
||||
out, err := yaml.Marshal(merged)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||
}
|
||||
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||
}
|
||||
|
||||
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||
func listYAMLFiles(dir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var files []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
files = append(files, filepath.Join(dir, name))
|
||||
}
|
||||
sort.Strings(files)
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||
// Returns a nil node (no error) when the file is empty or contains only
|
||||
// comments.
|
||||
//
|
||||
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||
func parseSource(path string) (*yaml.Node, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||
}
|
||||
yamlStr, err := substituteEnvMacros(string(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||
}
|
||||
var doc yaml.Node
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||
}
|
||||
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||
root := &doc
|
||||
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||
root = root.Content[0]
|
||||
}
|
||||
if root.Kind == 0 || root.Content == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if root.Kind != yaml.MappingNode {
|
||||
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||
// only one side are kept; shared keys are merged recursively under the rules
|
||||
// in mergeValue. srcPath is included in error messages to identify the file
|
||||
// that introduced the conflict.
|
||||
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||
srcIdx := indexMapping(src)
|
||||
|
||||
// First pass: merge shared keys in place.
|
||||
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||
keyNode := dst.Content[i]
|
||||
dstVal := dst.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
srcVal, ok := srcIdx[key]
|
||||
if !ok {
|
||||
continue // dst-only key, keep as-is
|
||||
}
|
||||
|
||||
childPath := joinPath(path, key)
|
||||
|
||||
if identityMapPaths[childPath] {
|
||||
// Identity-keyed map: each child key names a discrete entity
|
||||
// (a model, group, peer, ...). A shared child key is a hard
|
||||
// error; src-only children are appended in the second pass.
|
||||
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: append src-only keys.
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
if _, ok := dstIdx[key]; ok {
|
||||
continue // already merged above
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||
// entity and produces an error naming the conflicting key and source file.
|
||||
// src-only keys are appended to dst.
|
||||
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||
}
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
if _, dup := dstIdx[key]; dup {
|
||||
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||
// non-null side.
|
||||
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||
switch {
|
||||
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||
|
||||
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||
return nil
|
||||
|
||||
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||
if isNullScalar(dstVal) {
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
}
|
||||
if isNullScalar(srcVal) {
|
||||
return nil
|
||||
}
|
||||
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||
|
||||
case isNull(dstVal):
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
|
||||
case isNull(srcVal):
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||
}
|
||||
}
|
||||
|
||||
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||
func isNull(n *yaml.Node) bool {
|
||||
if n == nil || n.Kind == 0 {
|
||||
return true
|
||||
}
|
||||
return isNullScalar(n)
|
||||
}
|
||||
|
||||
func isNullScalar(n *yaml.Node) bool {
|
||||
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||
}
|
||||
|
||||
// indexMapping builds a key -> value-node index for a mapping node.
|
||||
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||
idx[n.Content[i].Value] = n.Content[i+1]
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func joinPath(parent, key string) string {
|
||||
if parent == "" {
|
||||
return key
|
||||
}
|
||||
return parent + "." + key
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||
// path of the written file.
|
||||
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(dir, name)
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||
return p
|
||||
}
|
||||
|
||||
// modelCfg builds a single-model YAML snippet indented for nesting under a
|
||||
// `models:` key. The proxy uses a fixed port so tests don't depend on
|
||||
// ${PORT} allocation.
|
||||
func modelCfg(id, cmd string) string {
|
||||
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||
_, err := LoadConfigSources("", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", `
|
||||
models:
|
||||
`+modelCfg("model1", "echo hi")+`
|
||||
groups:
|
||||
group1:
|
||||
members: ["model1"]
|
||||
`)
|
||||
cfg, err := LoadConfigSources(cfgPath, "")
|
||||
require.NoError(t, err)
|
||||
_, id, ok := cfg.FindConfig("model1")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "model1", id)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"alpha", "beta"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present", want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||
// -config lives outside -config-dir; both contribute models additively.
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
cfgDir := t.TempDir()
|
||||
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"base", "ext"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present after merge", want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||
// is also a member of -config-dir is rejected.
|
||||
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
|
||||
_, err := LoadConfigSources(cfgPath, dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// Both macros are available globally after merge.
|
||||
low, ok := cfg.Macros.Get("LOW")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 1, low)
|
||||
high, ok := cfg.Macros.Get("HIGH")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 2, high)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupA:
|
||||
members: [m1]
|
||||
`)
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupB:
|
||||
members: [m2]
|
||||
`)
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
groups := cfg.Routing.Router.Settings.Groups
|
||||
assert.Contains(t, groups, "groupA")
|
||||
assert.Contains(t, groups, "groupB")
|
||||
// default group added by pipeline for orphaned/leftover routing groups...
|
||||
// here both groups reference distinct models
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||
// verifies env/macro substitution runs on the merged document.
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["m1"]
|
||||
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||
// parseSource unmarshalled before env substitution, so the brace in
|
||||
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||
// Two files defining distinct models, scanned in z..a order by filename.
|
||||
// Determine merged result is the same regardless of how the FS returns them.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||
|
||||
const runs = 3
|
||||
for i := 0; i < runs; i++ {
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// startPort-based allocation: first allocated model gets 5800.
|
||||
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||
_, _, ok := cfg.FindConfig("amodel")
|
||||
assert.True(t, ok)
|
||||
_, _, ok = cfg.FindConfig("zmodel")
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgDir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Models, "m1")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||
// An empty -config-dir with no -config is an error: there is nothing to
|
||||
// load and silently producing an empty config would mask the misconfig.
|
||||
cfgDir := t.TempDir()
|
||||
_, err := LoadConfigSources("", cfgDir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||
// another — they do, because merge concats global macros before validation
|
||||
// runs. This test documents that a macro from file A is usable in file B.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["user"]
|
||||
assert.Contains(t, m.Cmd, "hello")
|
||||
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||
// File A: routing.router block absent (null on root for routing);
|
||||
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
@@ -9,6 +10,47 @@ const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
var validModalities = map[string]struct{}{
|
||||
"text": {},
|
||||
"audio": {},
|
||||
"image": {},
|
||||
}
|
||||
|
||||
// ModelCapConfig defines what modalities and features a model supports.
|
||||
// Used in /v1/models to inform clients. An empty block (all zero values) is
|
||||
// treated as not configured.
|
||||
type ModelCapConfig struct {
|
||||
In []string `yaml:"in"`
|
||||
Out []string `yaml:"out"`
|
||||
Tools bool `yaml:"tools"`
|
||||
Reranker bool `yaml:"reranker"`
|
||||
Context int `yaml:"context"`
|
||||
}
|
||||
|
||||
// Empty returns true when all fields are at their zero values.
|
||||
func (c ModelCapConfig) Empty() bool {
|
||||
return len(c.In) == 0 && len(c.Out) == 0 && !c.Tools && !c.Reranker && c.Context == 0
|
||||
}
|
||||
|
||||
// Validate checks that all modality values are recognized and context is
|
||||
// non-negative. Returns an error if any value is invalid.
|
||||
func (c ModelCapConfig) Validate() error {
|
||||
for _, m := range c.In {
|
||||
if _, ok := validModalities[m]; !ok {
|
||||
return fmt.Errorf("capabilities.in: invalid modality %q, must be one of: text, audio, image", m)
|
||||
}
|
||||
}
|
||||
for _, m := range c.Out {
|
||||
if _, ok := validModalities[m]; !ok {
|
||||
return fmt.Errorf("capabilities.out: invalid modality %q, must be one of: text, audio, image", m)
|
||||
}
|
||||
}
|
||||
if c.Context < 0 {
|
||||
return errors.New("capabilities.context: must be >= 0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
// 0 = no timeout
|
||||
type TimeoutsConfig struct {
|
||||
@@ -54,6 +96,12 @@ type ModelConfig struct {
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
|
||||
// Capabilities defines what modalities and features the model supports.
|
||||
Capabilities ModelCapConfig `yaml:"capabilities"`
|
||||
|
||||
// Copy of HealthCheckTimeout from global config
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
@@ -152,7 +152,7 @@ models:
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -170,3 +170,167 @@ models:
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
func TestConfig_ModelCapabilities(t *testing.T) {
|
||||
t.Run("all fields", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
in:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
out:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
tools: true
|
||||
context: 32000
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.In)
|
||||
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.Out)
|
||||
assert.True(t, mc.Capabilities.Tools)
|
||||
assert.Equal(t, 32000, mc.Capabilities.Context)
|
||||
})
|
||||
|
||||
t.Run("partial fields", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
tools: true
|
||||
context: 8192
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.Nil(t, mc.Capabilities.In)
|
||||
assert.Nil(t, mc.Capabilities.Out)
|
||||
assert.True(t, mc.Capabilities.Tools)
|
||||
assert.Equal(t, 8192, mc.Capabilities.Context)
|
||||
})
|
||||
|
||||
t.Run("not set", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
|
||||
t.Run("tools false is empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
tools: false
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
|
||||
t.Run("reranker true is not empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
reranker: true
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.True(t, mc.Capabilities.Reranker)
|
||||
})
|
||||
|
||||
t.Run("reranker false is empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
reranker: false
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_ModelCapabilities_Validate(t *testing.T) {
|
||||
t.Run("valid_modalities", func(t *testing.T) {
|
||||
caps := ModelCapConfig{
|
||||
In: []string{"text", "image"},
|
||||
Out: []string{"text", "audio"},
|
||||
Tools: true,
|
||||
Context: 100000,
|
||||
}
|
||||
assert.NoError(t, caps.Validate())
|
||||
})
|
||||
|
||||
t.Run("empty_is_valid", func(t *testing.T) {
|
||||
caps := ModelCapConfig{}
|
||||
assert.NoError(t, caps.Validate())
|
||||
})
|
||||
|
||||
t.Run("invalid_in_modality", func(t *testing.T) {
|
||||
caps := ModelCapConfig{In: []string{"video"}}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.in")
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
|
||||
t.Run("invalid_out_modality", func(t *testing.T) {
|
||||
caps := ModelCapConfig{Out: []string{"video"}}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.out")
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
|
||||
t.Run("negative_context", func(t *testing.T) {
|
||||
caps := ModelCapConfig{Context: -1}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.context")
|
||||
})
|
||||
|
||||
t.Run("rejects_invalid_at_load", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
in:
|
||||
- text
|
||||
- video
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// DefaultUpstreamIgnorePathsPattern is the default regular expression applied
|
||||
// to upstream.ignorePaths when the section is empty or absent from the config.
|
||||
// It matches common static-asset suffixes so requests for .js/.css/.png/etc.
|
||||
// files do not trigger a model swap.
|
||||
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`
|
||||
|
||||
// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used
|
||||
// when upstream.ignorePaths is not specified in the config. The returned slice
|
||||
// is fresh so callers may mutate it without affecting other configs.
|
||||
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
|
||||
return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)}
|
||||
}
|
||||
|
||||
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
|
||||
type UpstreamConfig struct {
|
||||
// IgnorePaths is a slice of compiled regular expressions. Any request to
|
||||
// /upstream/<model>/<path> whose remaining path matches any of these
|
||||
// expressions will be ignored and not trigger a swap. When the config
|
||||
// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
|
||||
IgnorePaths []*regexp.Regexp `yaml:"-"`
|
||||
}
|
||||
|
||||
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
|
||||
// plain strings, which are then compiled into *regexp.Regexp.
|
||||
type rawUpstreamConfig struct {
|
||||
IgnorePaths []string `yaml:"ignorePaths"`
|
||||
}
|
||||
|
||||
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
|
||||
// entry fails to compile, an error is returned.
|
||||
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
var raw rawUpstreamConfig
|
||||
if err := value.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
|
||||
for _, p := range raw.IgnorePaths {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
|
||||
}
|
||||
patterns = append(patterns, re)
|
||||
}
|
||||
u.IgnorePaths = patterns
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const upstreamConfigHeader = `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
|
||||
// When upstream is not specified at all, the default pattern is applied.
|
||||
content := upstreamConfigHeader
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||
|
||||
def := cfg.Upstream.IgnorePaths[0]
|
||||
assert.IsType(t, ®exp.Regexp{}, def)
|
||||
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
|
||||
|
||||
// The default matches common static-asset suffixes.
|
||||
assert.True(t, def.MatchString("/foo.js"))
|
||||
assert.True(t, def.MatchString("/bar/baz.json"))
|
||||
assert.True(t, def.MatchString("/static/img.png"))
|
||||
assert.True(t, def.MatchString("/notes.txt"))
|
||||
assert.True(t, def.MatchString("/favicon.ico"))
|
||||
// And does not match inference API paths.
|
||||
assert.False(t, def.MatchString("/v1/chat/completions"))
|
||||
assert.False(t, def.MatchString("/v1/models"))
|
||||
assert.False(t, def.MatchString("/health"))
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
|
||||
// When upstream is present but ignorePaths is omitted, the default is still
|
||||
// applied.
|
||||
content := `upstream: {}` + "\n" + upstreamConfigHeader
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
|
||||
content := `
|
||||
upstream:
|
||||
ignorePaths:
|
||||
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
|
||||
- "^/static/.*"
|
||||
` + upstreamConfigHeader
|
||||
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 2)
|
||||
|
||||
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
|
||||
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
|
||||
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
|
||||
|
||||
// Confirm the type is *regexp.Regexp to satisfy the API contract.
|
||||
for _, re := range cfg.Upstream.IgnorePaths {
|
||||
assert.IsType(t, ®exp.Regexp{}, re)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
|
||||
content := `
|
||||
upstream:
|
||||
ignorePaths:
|
||||
- "[invalid("
|
||||
` + upstreamConfigHeader
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "upstream.ignorePaths")
|
||||
assert.Contains(t, err.Error(), "invalid regular expression")
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
)
|
||||
|
||||
const DataEventID = 0x04
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package perf
|
||||
|
||||
type LUID struct {
|
||||
LowPart uint32
|
||||
HighPart int32
|
||||
}
|
||||
|
||||
const maxEnumAdapters = 16
|
||||
|
||||
type D3DKMT_ENUMADAPTERS2 struct {
|
||||
NumAdapters uint32
|
||||
pAdapters uintptr
|
||||
}
|
||||
|
||||
type D3DKMT_ADAPTERINFO struct {
|
||||
hAdapter uint32
|
||||
AdapterLuid LUID
|
||||
NumOfSources uint32
|
||||
bPresentMoveRegionsPreferred int32
|
||||
}
|
||||
|
||||
type D3DKMT_OPENADAPTERFROMLUID struct {
|
||||
AdapterLuid LUID
|
||||
hAdapter uint32
|
||||
}
|
||||
|
||||
type D3DKMT_CLOSEADAPTER struct {
|
||||
hAdapter uint32
|
||||
}
|
||||
|
||||
type KMTQUERYADAPTERINFOTYPE int32
|
||||
|
||||
const (
|
||||
KMTQAITYPE_UMDRIVERPRIVATE KMTQUERYADAPTERINFOTYPE = 0
|
||||
KMTQAITYPE_ADAPTERREGISTRYINFO KMTQUERYADAPTERINFOTYPE = 8
|
||||
KMTQAITYPE_DRIVERVERSION KMTQUERYADAPTERINFOTYPE = 13
|
||||
KMTQAITYPE_PHYSICALADAPTERDEVICEIDS KMTQUERYADAPTERINFOTYPE = 31
|
||||
KMTQAITYPE_NODEPERFDATA KMTQUERYADAPTERINFOTYPE = 61
|
||||
KMTQAITYPE_ADAPTERPERFDATA KMTQUERYADAPTERINFOTYPE = 62
|
||||
KMTQAITYPE_ADAPTERPERFDATA_CAPS KMTQUERYADAPTERINFOTYPE = 63
|
||||
)
|
||||
|
||||
type D3DKMT_QUERYADAPTERINFO struct {
|
||||
hAdapter uint32
|
||||
Type KMTQUERYADAPTERINFOTYPE
|
||||
pPrivateDriverData uintptr
|
||||
PrivateDriverDataSize uint32
|
||||
}
|
||||
|
||||
type D3DKMT_ADAPTER_PERFDATA struct {
|
||||
PhysicalAdapterIndex uint32
|
||||
MemoryFrequency uint64
|
||||
MaxMemoryFrequency uint64
|
||||
MaxMemoryFrequencyOC uint64
|
||||
MemoryBandwidth uint64
|
||||
PCIEBandwidth uint64
|
||||
FanRPM uint32
|
||||
Power uint32
|
||||
Temperature uint32
|
||||
PowerStateOverride byte
|
||||
}
|
||||
|
||||
type D3DKMT_QUERYSTATISTICS_TYPE int32
|
||||
|
||||
const (
|
||||
D3DKMT_QUERYSTATISTICS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 0
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS D3DKMT_QUERYSTATISTICS_TYPE = 1
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 2
|
||||
D3DKMT_QUERYSTATISTICS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 3
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 4
|
||||
D3DKMT_QUERYSTATISTICS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 5
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 6
|
||||
D3DKMT_QUERYSTATISTICS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 7
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 8
|
||||
)
|
||||
|
||||
type D3DKMT_ADAPTER_PERFDATACAPS struct {
|
||||
PhysicalAdapterIndex uint32
|
||||
MaxMemoryBandwidth uint64
|
||||
MaxPCIEBandwidth uint64
|
||||
MaxFanRPM uint32
|
||||
TemperatureMax uint32
|
||||
TemperatureWarning uint32
|
||||
}
|
||||
|
||||
type D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT struct {
|
||||
SegmentId uint32
|
||||
}
|
||||
|
||||
type D3DKMT_QUERYSTATISTICS_QUERY_NODE struct {
|
||||
NodeId uint32
|
||||
}
|
||||
@@ -0,0 +1,529 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
d3dkmDLL *windows.LazyDLL
|
||||
procEnumAdapters2 *windows.LazyProc
|
||||
procOpenAdapterFromLuid *windows.LazyProc
|
||||
procCloseAdapter *windows.LazyProc
|
||||
procQueryAdapterInfo *windows.LazyProc
|
||||
procQueryStatistics *windows.LazyProc
|
||||
d3dkmtInitOnce sync.Once
|
||||
d3dkmtInitErr error
|
||||
)
|
||||
|
||||
// initD3DKMT lazily loads gdi32.dll and resolves D3DKMT function pointers.
|
||||
// Safe for concurrent use via sync.Once.
|
||||
func initD3DKMT() error {
|
||||
d3dkmtInitOnce.Do(func() {
|
||||
d3dkmDLL = windows.NewLazySystemDLL("gdi32.dll")
|
||||
|
||||
procEnumAdapters2 = d3dkmDLL.NewProc("D3DKMTEnumAdapters2")
|
||||
procOpenAdapterFromLuid = d3dkmDLL.NewProc("D3DKMTOpenAdapterFromLuid")
|
||||
procCloseAdapter = d3dkmDLL.NewProc("D3DKMTCloseAdapter")
|
||||
procQueryAdapterInfo = d3dkmDLL.NewProc("D3DKMTQueryAdapterInfo")
|
||||
procQueryStatistics = d3dkmDLL.NewProc("D3DKMTQueryStatistics")
|
||||
|
||||
for name, p := range map[string]*windows.LazyProc{
|
||||
"D3DKMTEnumAdapters2": procEnumAdapters2,
|
||||
"D3DKMTOpenAdapterFromLuid": procOpenAdapterFromLuid,
|
||||
"D3DKMTCloseAdapter": procCloseAdapter,
|
||||
"D3DKMTQueryAdapterInfo": procQueryAdapterInfo,
|
||||
"D3DKMTQueryStatistics": procQueryStatistics,
|
||||
} {
|
||||
if err := p.Find(); err != nil {
|
||||
d3dkmtInitErr = fmt.Errorf("D3DKMT %s not found: %w", name, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return d3dkmtInitErr
|
||||
}
|
||||
|
||||
// ntstatusCall invokes a D3DKMT function and returns a non-nil error if the
|
||||
// NTSTATUS result is not STATUS_SUCCESS (0).
|
||||
func ntstatusCall(proc *windows.LazyProc, arg unsafe.Pointer) error {
|
||||
ret, _, _ := proc.Call(uintptr(arg))
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("NTSTATUS 0x%08x", uint32(ret))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// d3dkmEnumerateAdapters enumerates all available graphics adapters via
|
||||
// D3DKMTEnumAdapters2.
|
||||
func d3dkmEnumerateAdapters() ([]D3DKMT_ADAPTERINFO, error) {
|
||||
var adapters [maxEnumAdapters]D3DKMT_ADAPTERINFO
|
||||
enum := D3DKMT_ENUMADAPTERS2{
|
||||
NumAdapters: maxEnumAdapters,
|
||||
pAdapters: uintptr(unsafe.Pointer(&adapters[0])),
|
||||
}
|
||||
if err := ntstatusCall(procEnumAdapters2, unsafe.Pointer(&enum)); err != nil {
|
||||
return nil, fmt.Errorf("EnumAdapters2: %w", err)
|
||||
}
|
||||
if enum.NumAdapters == 0 {
|
||||
return nil, fmt.Errorf("no adapters found")
|
||||
}
|
||||
result := make([]D3DKMT_ADAPTERINFO, enum.NumAdapters)
|
||||
for i := uint32(0); i < enum.NumAdapters; i++ {
|
||||
result[i] = adapters[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// d3dkmOpenAdapter opens a D3DKMT adapter handle for the given LUID.
|
||||
func d3dkmOpenAdapter(luid LUID) (uint32, error) {
|
||||
req := D3DKMT_OPENADAPTERFROMLUID{
|
||||
AdapterLuid: luid,
|
||||
}
|
||||
if err := ntstatusCall(procOpenAdapterFromLuid, unsafe.Pointer(&req)); err != nil {
|
||||
return 0, fmt.Errorf("OpenAdapterFromLuid: %w", err)
|
||||
}
|
||||
return req.hAdapter, nil
|
||||
}
|
||||
|
||||
// d3dkmCloseAdapter closes a previously opened D3DKMT adapter handle.
|
||||
func d3dkmCloseAdapter(hAdapter uint32) error {
|
||||
req := D3DKMT_CLOSEADAPTER{hAdapter: hAdapter}
|
||||
return ntstatusCall(procCloseAdapter, unsafe.Pointer(&req))
|
||||
}
|
||||
|
||||
// d3dkmGetAdapterPerfData queries per-adapter performance data (temperature,
|
||||
// fan RPM, power, bandwidth) via KMTQAITYPE_ADAPTERPERFDATA.
|
||||
func d3dkmGetAdapterPerfData(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATA, error) {
|
||||
var data D3DKMT_ADAPTER_PERFDATA
|
||||
req := D3DKMT_QUERYADAPTERINFO{
|
||||
hAdapter: hAdapter,
|
||||
Type: KMTQAITYPE_ADAPTERPERFDATA,
|
||||
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
|
||||
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
|
||||
}
|
||||
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
|
||||
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATA): %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// d3dkmGetAdapterPerfDataCaps queries static adapter performance capabilities
|
||||
// (max fan RPM, temperature limits, max bandwidth) via KMTQAITYPE_ADAPTERPERFDATA_CAPS.
|
||||
func d3dkmGetAdapterPerfDataCaps(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATACAPS, error) {
|
||||
var data D3DKMT_ADAPTER_PERFDATACAPS
|
||||
req := D3DKMT_QUERYADAPTERINFO{
|
||||
hAdapter: hAdapter,
|
||||
Type: KMTQAITYPE_ADAPTERPERFDATA_CAPS,
|
||||
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
|
||||
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
|
||||
}
|
||||
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
|
||||
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATACAPS): %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
type queryStatsBuffer struct {
|
||||
Type int32 // offset 0
|
||||
AdapterLuid LUID // offset 4
|
||||
hProcess uintptr // offset 16
|
||||
// _result mirrors the D3DKMT_QUERYSTATISTICS_RESULT union.
|
||||
// sizeof(D3DKMT_QUERYSTATISTICS) == 0x328 (808 bytes) on x64.
|
||||
//
|
||||
// The C struct layout (x64):
|
||||
// offset 0: Type (int32, 4 bytes)
|
||||
// offset 4: AdapterLuid (LUID, 8 bytes)
|
||||
// offset 12: 4 bytes padding (for 8-byte alignment of hProcess)
|
||||
// offset 16: hProcess (HANDLE, 8 bytes)
|
||||
// offset 24: QueryResult (union, 780 bytes — largest member is AdapterInformation)
|
||||
// offset 804: anonymous input union (QueryNode.NodeId / QuerySegment.SegmentId, 4 bytes)
|
||||
//
|
||||
// Previous bug: _result was [776]byte, placing QueryId at offset 800 instead of 804.
|
||||
// The kernel read NodeId/SegmentId from offset 804 (always zero from _pad),
|
||||
// causing all NODE and SEGMENT queries to use index 0 regardless of the value
|
||||
// passed in QueryId. This produced alternating behavior where only GPU util OR
|
||||
// memory util appeared to work, depending on which test variant happened to put
|
||||
// non-zero data near offset 804 in the result buffer.
|
||||
_result [780]byte // offset 24, size 780 — places QueryId at offset 804
|
||||
QueryId int32 // offset 804 — matches C anonymous union for NodeId/SegmentId
|
||||
}
|
||||
|
||||
func init() {
|
||||
var buf queryStatsBuffer
|
||||
if unsafe.Sizeof(buf) != 808 {
|
||||
panic(fmt.Sprintf("queryStatsBuffer size %d != expected 808 (sizeof D3DKMT_QUERYSTATISTICS on x64)", unsafe.Sizeof(buf)))
|
||||
}
|
||||
if unsafe.Offsetof(buf.QueryId) != 804 {
|
||||
panic(fmt.Sprintf("queryStatsBuffer.QueryId offset %d != expected 804 (C anonymous union offset)", unsafe.Offsetof(buf.QueryId)))
|
||||
}
|
||||
|
||||
var perfData D3DKMT_ADAPTER_PERFDATA
|
||||
if unsafe.Sizeof(perfData) != 64 {
|
||||
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATA size %d != expected 64 on x64", unsafe.Sizeof(perfData)))
|
||||
}
|
||||
|
||||
var caps D3DKMT_ADAPTER_PERFDATACAPS
|
||||
if unsafe.Sizeof(caps) != 40 {
|
||||
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATACAPS size %d != expected 40 on x64", unsafe.Sizeof(caps)))
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
qsoffsetNbSegments = 0
|
||||
qsoffsetNodeCount = 4
|
||||
qsoffsetCommitLimit = 0
|
||||
qsoffsetBytesCommitted = 8
|
||||
qsoffsetBytesResident = 16
|
||||
qsoffsetRunningTime = 0
|
||||
qsoffsetSystemRunningTime = 272
|
||||
)
|
||||
|
||||
// d3dkmQueryAdapterStats returns the number of memory segments and compute
|
||||
// nodes for the adapter identified by luid.
|
||||
func d3dkmQueryAdapterStats(luid LUID) (nbSegments uint32, nodeCount uint32, err error) {
|
||||
buf := queryStatsBuffer{
|
||||
Type: int32(D3DKMT_QUERYSTATISTICS_ADAPTER),
|
||||
AdapterLuid: luid,
|
||||
}
|
||||
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||
return 0, 0, fmt.Errorf("QueryStatistics(ADAPTER): %w", err)
|
||||
}
|
||||
nbSegments = binary.LittleEndian.Uint32(buf._result[qsoffsetNbSegments : qsoffsetNbSegments+4])
|
||||
nodeCount = binary.LittleEndian.Uint32(buf._result[qsoffsetNodeCount : qsoffsetNodeCount+4])
|
||||
return nbSegments, nodeCount, nil
|
||||
}
|
||||
|
||||
// d3dkmQuerySegmentStats returns the commit limit (total) and resident
|
||||
// (used) bytes for the given memory segment of an adapter.
|
||||
func d3dkmQuerySegmentStats(luid LUID, segmentID uint32) (commitLimit uint64, bytesResident uint64, err error) {
|
||||
buf := queryStatsBuffer{
|
||||
Type: int32(D3DKMT_QUERYSTATISTICS_SEGMENT),
|
||||
AdapterLuid: luid,
|
||||
QueryId: int32(segmentID),
|
||||
}
|
||||
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||
return 0, 0, fmt.Errorf("QueryStatistics(SEGMENT %d): %w", segmentID, err)
|
||||
}
|
||||
commitLimit = binary.LittleEndian.Uint64(buf._result[qsoffsetCommitLimit : qsoffsetCommitLimit+8])
|
||||
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesResident : qsoffsetBytesResident+8])
|
||||
if bytesResident == 0 {
|
||||
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesCommitted : qsoffsetBytesCommitted+8])
|
||||
}
|
||||
return commitLimit, bytesResident, nil
|
||||
}
|
||||
|
||||
// d3dkmQueryNodeStats returns the global and system running time counters
|
||||
// (in 100ns units) for the given compute node of an adapter.
|
||||
func d3dkmQueryNodeStats(luid LUID, nodeID uint32) (runningTime uint64, systemRunningTime uint64, err error) {
|
||||
buf := queryStatsBuffer{
|
||||
Type: int32(D3DKMT_QUERYSTATISTICS_NODE),
|
||||
AdapterLuid: luid,
|
||||
QueryId: int32(nodeID),
|
||||
}
|
||||
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||
return 0, 0, fmt.Errorf("QueryStatistics(NODE %d): %w", nodeID, err)
|
||||
}
|
||||
runningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetRunningTime : qsoffsetRunningTime+8])
|
||||
systemRunningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetSystemRunningTime : qsoffsetSystemRunningTime+8])
|
||||
return runningTime, systemRunningTime, nil
|
||||
}
|
||||
|
||||
type nodeRunningTimes struct {
|
||||
Global uint64
|
||||
System uint64
|
||||
}
|
||||
|
||||
// d3dkmtNodeUtil computes GPU node utilization as a percentage from running
|
||||
// time deltas. Returns -1 if counters went backwards (wrap/reset), 0 if idle.
|
||||
func d3dkmtNodeUtil(prevRT, curRT nodeRunningTimes, elapsed100ns int64) float64 {
|
||||
if curRT.Global < prevRT.Global || curRT.System < prevRT.System {
|
||||
return -1
|
||||
}
|
||||
gd := curRT.Global - prevRT.Global
|
||||
sd := curRT.System - prevRT.System
|
||||
|
||||
if gd > 0 && sd > 0 {
|
||||
util := float64(gd) / float64(sd)
|
||||
if util > 1.0 {
|
||||
util = 1.0
|
||||
}
|
||||
return util * 100.0
|
||||
} else if gd > 0 && elapsed100ns > 0 {
|
||||
util := float64(gd) / float64(elapsed100ns) * 100.0
|
||||
if util > 100.0 {
|
||||
util = 100.0
|
||||
}
|
||||
return util
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// d3dkmtFanPct returns fan speed as a percentage of maxFanRPM, clamped to
|
||||
// 100%. Returns 0 if maxFanRPM is unavailable or fan is not spinning.
|
||||
func d3dkmtFanPct(fanRPM, maxFanRPM uint32) float64 {
|
||||
if maxFanRPM > 0 && fanRPM > 0 {
|
||||
pct := float64(fanRPM) / float64(maxFanRPM) * 100.0
|
||||
if pct > 100.0 {
|
||||
pct = 100.0
|
||||
}
|
||||
return pct
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// d3dkmtPowerW converts power from deci-watts (as reported by D3DKMT) to
|
||||
// watts. Returns 0 if the power value is zero.
|
||||
func d3dkmtPowerW(power uint32) float64 {
|
||||
if power > 0 {
|
||||
return float64(power) / 10.0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// d3dkmtTempC converts temperature from deci-Celsius (as reported by D3DKMT)
|
||||
// to degrees Celsius.
|
||||
func d3dkmtTempC(tempDeciC uint32) int {
|
||||
return int(tempDeciC / 10)
|
||||
}
|
||||
|
||||
type d3dkmtAdapterState struct {
|
||||
luid LUID
|
||||
hAdapter uint32
|
||||
nbSegments uint32
|
||||
nodeCount uint32
|
||||
maxFanRPM uint32
|
||||
prevNodeRT map[uint32]nodeRunningTimes
|
||||
prevTime time.Time
|
||||
}
|
||||
|
||||
// tryD3DKMT attempts to start GPU monitoring using D3DKMT and optional PDH
|
||||
// counters. It returns a channel of GpuStat snapshots or an error if no
|
||||
// usable adapters are found.
|
||||
func tryD3DKMT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if err := initD3DKMT(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
adapterInfos, err := d3dkmEnumerateAdapters()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type adapterMeta struct {
|
||||
luid LUID
|
||||
nbSegments uint32
|
||||
nodeCount uint32
|
||||
maxFanRPM uint32
|
||||
}
|
||||
|
||||
var metaList []adapterMeta
|
||||
|
||||
for i, ai := range adapterInfos {
|
||||
hAdapter, err := d3dkmOpenAdapter(ai.AdapterLuid)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d: open failed: %s", i, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
nbSegments, nodeCount, err := d3dkmQueryAdapterStats(ai.AdapterLuid)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d: query stats failed: %s", i, err.Error())
|
||||
d3dkmCloseAdapter(hAdapter)
|
||||
continue
|
||||
}
|
||||
|
||||
caps, err := d3dkmGetAdapterPerfDataCaps(hAdapter)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d: perf caps failed: %s", i, err.Error())
|
||||
}
|
||||
|
||||
d3dkmCloseAdapter(hAdapter)
|
||||
|
||||
var maxFanRPM uint32
|
||||
if caps != nil {
|
||||
maxFanRPM = caps.MaxFanRPM
|
||||
}
|
||||
|
||||
metaList = append(metaList, adapterMeta{
|
||||
luid: ai.AdapterLuid,
|
||||
nbSegments: nbSegments,
|
||||
nodeCount: nodeCount,
|
||||
maxFanRPM: maxFanRPM,
|
||||
})
|
||||
logger.Debugf("adapter %d: segments=%d nodes=%d fan_max=%d luid=%d:%d", i, nbSegments, nodeCount, maxFanRPM, ai.AdapterLuid.HighPart, ai.AdapterLuid.LowPart)
|
||||
}
|
||||
|
||||
if len(metaList) == 0 {
|
||||
return nil, fmt.Errorf("no usable D3DKMT adapters found")
|
||||
}
|
||||
|
||||
pdhUtil, pdhErr := initPdhGpuUtil()
|
||||
if pdhErr != nil {
|
||||
logger.Debugf("PDH GPU utilization not available: %s", pdhErr.Error())
|
||||
} else {
|
||||
logger.Info("using PDH performance counters for GPU utilization")
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if pdhUtil != nil {
|
||||
defer pdhUtil.close()
|
||||
}
|
||||
|
||||
var adapters []d3dkmtAdapterState
|
||||
for _, m := range metaList {
|
||||
hAdapter, err := d3dkmOpenAdapter(m.luid)
|
||||
if err != nil {
|
||||
logger.Debugf("reopen adapter failed: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
adapters = append(adapters, d3dkmtAdapterState{
|
||||
luid: m.luid,
|
||||
hAdapter: hAdapter,
|
||||
nbSegments: m.nbSegments,
|
||||
nodeCount: m.nodeCount,
|
||||
maxFanRPM: m.maxFanRPM,
|
||||
prevNodeRT: make(map[uint32]nodeRunningTimes),
|
||||
})
|
||||
}
|
||||
|
||||
if len(adapters) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, a := range adapters {
|
||||
d3dkmCloseAdapter(a.hAdapter)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := range adapters {
|
||||
a := &adapters[i]
|
||||
for node := uint32(0); node < a.nodeCount; node++ {
|
||||
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
|
||||
}
|
||||
a.prevTime = time.Now()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(every)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
stats := make([]GpuStat, 0, len(adapters))
|
||||
now := time.Now()
|
||||
|
||||
var pdhUtilMap map[LUID]float64
|
||||
if pdhUtil != nil {
|
||||
pdhUtilMap = pdhUtil.collect()
|
||||
}
|
||||
|
||||
for i := range adapters {
|
||||
a := &adapters[i]
|
||||
|
||||
perfData, err := d3dkmGetAdapterPerfData(a.hAdapter)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d perfdata: %s", i, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
var memUsedMB, memTotalMB int
|
||||
for seg := uint32(0); seg < a.nbSegments; seg++ {
|
||||
limit, resident, err := d3dkmQuerySegmentStats(a.luid, seg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
memUsedMB += int(resident / (1024 * 1024))
|
||||
memTotalMB += int(limit / (1024 * 1024))
|
||||
}
|
||||
|
||||
var gpuUtil float64
|
||||
pdhGaveValue := false
|
||||
if pdhUtilMap != nil {
|
||||
if util, ok := pdhUtilMap[a.luid]; ok {
|
||||
gpuUtil = util
|
||||
pdhGaveValue = true
|
||||
}
|
||||
}
|
||||
|
||||
if !pdhGaveValue && a.nodeCount > 0 {
|
||||
elapsedNs := now.Sub(a.prevTime).Nanoseconds()
|
||||
elapsed100ns := elapsedNs / 100
|
||||
|
||||
for node := uint32(0); node < a.nodeCount; node++ {
|
||||
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if prevRT, ok := a.prevNodeRT[node]; ok {
|
||||
if globalRT < prevRT.Global || systemRT < prevRT.System {
|
||||
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
|
||||
continue
|
||||
}
|
||||
nodeUtil := d3dkmtNodeUtil(prevRT, nodeRunningTimes{Global: globalRT, System: systemRT}, elapsed100ns)
|
||||
if nodeUtil > gpuUtil {
|
||||
gpuUtil = nodeUtil
|
||||
}
|
||||
}
|
||||
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
|
||||
}
|
||||
|
||||
a.prevTime = now
|
||||
}
|
||||
|
||||
tempC := d3dkmtTempC(perfData.Temperature)
|
||||
|
||||
fanSpeedPct := d3dkmtFanPct(perfData.FanRPM, a.maxFanRPM)
|
||||
powerDrawW := d3dkmtPowerW(perfData.Power)
|
||||
|
||||
var memUtilPct float64
|
||||
if memTotalMB > 0 {
|
||||
memUtilPct = float64(memUsedMB) / float64(memTotalMB) * 100.0
|
||||
}
|
||||
|
||||
stats = append(stats, GpuStat{
|
||||
Timestamp: now,
|
||||
ID: i,
|
||||
Name: fmt.Sprintf("GPU %d", i),
|
||||
TempC: tempC,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtilPct,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
FanSpeedPct: fanSpeedPct,
|
||||
PowerDrawW: powerDrawW,
|
||||
})
|
||||
}
|
||||
|
||||
if len(stats) > 0 {
|
||||
select {
|
||||
case ch <- stats:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestD3dkmtNodeUtil_FullLoad(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 5000, System: 14000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 100.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_PartialUtil(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 3000, System: 14000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 50.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_Identical(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 10000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 20000, System: 20000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 100.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_CounterWrap(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 9000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, -1.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_SystemWrap(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 9000}
|
||||
cur := nodeRunningTimes{Global: 5000, System: 1000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, -1.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_ZeroDelta(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 0.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_ElapsedFallback(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 6000, System: 10000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 50000)
|
||||
assert.InDelta(t, 10.0, got, 0.01)
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_Normal(t *testing.T) {
|
||||
assert.Equal(t, 50.0, d3dkmtFanPct(1500, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_MaxFan(t *testing.T) {
|
||||
assert.Equal(t, 100.0, d3dkmtFanPct(3000, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_OverMaxClamped(t *testing.T) {
|
||||
assert.Equal(t, 100.0, d3dkmtFanPct(4000, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_ZeroMaxFan(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtFanPct(1500, 0))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_ZeroFanRPM(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtFanPct(0, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_BothZero(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtFanPct(0, 0))
|
||||
}
|
||||
|
||||
func TestD3dkmtPowerW(t *testing.T) {
|
||||
assert.Equal(t, 250.0, d3dkmtPowerW(2500))
|
||||
}
|
||||
|
||||
func TestD3dkmtPowerW_Zero(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtPowerW(0))
|
||||
}
|
||||
|
||||
func TestD3dkmtTempC(t *testing.T) {
|
||||
assert.Equal(t, 65, d3dkmtTempC(650))
|
||||
}
|
||||
|
||||
func TestD3dkmtTempC_Zero(t *testing.T) {
|
||||
assert.Equal(t, 0, d3dkmtTempC(0))
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ParseNvidiaSmiLine parses a single line from nvidia-smi CSV output.
|
||||
// Format: index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw
|
||||
func ParseNvidiaSmiLine(line string) *GpuStat {
|
||||
fields := strings.Split(line, ",")
|
||||
if len(fields) < 9 {
|
||||
return nil
|
||||
}
|
||||
|
||||
id, _ := strconv.Atoi(strings.TrimSpace(fields[0]))
|
||||
name := strings.TrimSpace(fields[1])
|
||||
uuid := strings.TrimSpace(fields[2])
|
||||
tempC, _ := strconv.Atoi(strings.TrimSpace(fields[3]))
|
||||
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[4]), 64)
|
||||
memUsed, _ := strconv.Atoi(strings.TrimSpace(fields[5]))
|
||||
memTotal, _ := strconv.Atoi(strings.TrimSpace(fields[6]))
|
||||
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[7]), 64)
|
||||
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
|
||||
|
||||
var memUtil float64
|
||||
if memTotal > 0 {
|
||||
memUtil = float64(memUsed) / float64(memTotal) * 100
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: id,
|
||||
Name: name,
|
||||
UUID: uuid,
|
||||
TempC: tempC,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsed,
|
||||
MemTotalMB: memTotal,
|
||||
FanSpeedPct: fanSpeed,
|
||||
PowerDrawW: powerDraw,
|
||||
}
|
||||
}
|
||||
|
||||
// mactopOutput maps the subset of mactop's headless JSON output that is
|
||||
// relevant to GpuStat. Note that mactop's memory object is whole-system memory,
|
||||
// not GPU-attributed; the darwin monitor overlays ioreg's GPU-attributed
|
||||
// unified memory (see overlayIoregMem) so both backends report consistent
|
||||
// memory figures.
|
||||
type mactopOutput struct {
|
||||
SocMetrics struct {
|
||||
GPUPower float64 `json:"gpu_power"`
|
||||
GPUFreq int `json:"gpu_freq_mhz"`
|
||||
GPUTemp float64 `json:"gpu_temp"`
|
||||
} `json:"soc_metrics"`
|
||||
Memory struct {
|
||||
Total uint64 `json:"total"`
|
||||
Used uint64 `json:"used"`
|
||||
} `json:"memory"`
|
||||
GPUUsage float64 `json:"gpu_usage"`
|
||||
SystemInfo struct {
|
||||
Name string `json:"name"`
|
||||
GPUCoreCount int `json:"gpu_core_count"`
|
||||
} `json:"system_info"`
|
||||
Fans []struct {
|
||||
RPM int `json:"rpm"`
|
||||
MinRPM int `json:"min_rpm"`
|
||||
MaxRPM int `json:"max_rpm"`
|
||||
} `json:"fans"`
|
||||
Temperatures []struct {
|
||||
Group string `json:"group"`
|
||||
Avg float64 `json:"avg_celsius"`
|
||||
} `json:"temperatures"`
|
||||
}
|
||||
|
||||
// ioreg output uses ` = ` (with spaces) for top-level device properties and
|
||||
// `=` (no spaces) for values inside nested dictionaries such as
|
||||
// PerformanceStatistics.
|
||||
var (
|
||||
reIoregModel = regexp.MustCompile(`"model"\s*=\s*"([^"]+)"`)
|
||||
reIoregCoreCount = regexp.MustCompile(`"gpu-core-count"\s*=\s*(\d+)`)
|
||||
reIoregUtil = regexp.MustCompile(`"Device Utilization %"=(\d+)`)
|
||||
reIoregMemUsed = regexp.MustCompile(`"In use system memory"=(\d+)`)
|
||||
)
|
||||
|
||||
// ParseIoregOutput parses `ioreg -r -c IOGPU -d 1 -f` output into a GpuStat for
|
||||
// the Apple Silicon integrated GPU. This is a fallback for when mactop is not
|
||||
// installed: utilization and used memory are available, but power, temperature,
|
||||
// and fan speed are not exposed by ioreg. memTotalMB is the unified memory size
|
||||
// supplied by the caller, since Apple Silicon shares memory between CPU and GPU.
|
||||
// Returns nil if no GPU device is found in the output.
|
||||
func ParseIoregOutput(out []byte, memTotalMB int) *GpuStat {
|
||||
utilMatch := reIoregUtil.FindSubmatch(out)
|
||||
memMatch := reIoregMemUsed.FindSubmatch(out)
|
||||
if utilMatch == nil && memMatch == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var gpuUtil float64
|
||||
if utilMatch != nil {
|
||||
gpuUtil, _ = strconv.ParseFloat(string(utilMatch[1]), 64)
|
||||
}
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
var memUsedMB int
|
||||
if memMatch != nil {
|
||||
memUsedBytes, _ := strconv.ParseInt(string(memMatch[1]), 10, 64)
|
||||
memUsedMB = int(memUsedBytes / toMB)
|
||||
}
|
||||
|
||||
var memUtil float64
|
||||
if memTotalMB > 0 {
|
||||
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||
}
|
||||
|
||||
name := "Apple GPU"
|
||||
if m := reIoregModel.FindSubmatch(out); m != nil {
|
||||
name = string(m[1])
|
||||
}
|
||||
if m := reIoregCoreCount.FindSubmatch(out); m != nil {
|
||||
if cores, err := strconv.Atoi(string(m[1])); err == nil && cores > 0 {
|
||||
name = fmt.Sprintf("%s (%d-core GPU)", name, cores)
|
||||
}
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: 0,
|
||||
Name: name,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseMactopLine parses a single line of mactop headless JSON output into a
|
||||
// GpuStat for the Apple Silicon integrated GPU. Returns nil if the line cannot
|
||||
// be parsed.
|
||||
func ParseMactopLine(line string) *GpuStat {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var out mactopOutput
|
||||
if err := json.Unmarshal([]byte(line), &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
memUsedMB := int(out.Memory.Used / toMB)
|
||||
memTotalMB := int(out.Memory.Total / toMB)
|
||||
|
||||
var memUtil float64
|
||||
if memTotalMB > 0 {
|
||||
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||
}
|
||||
|
||||
name := out.SystemInfo.Name
|
||||
if name == "" {
|
||||
name = "Apple GPU"
|
||||
}
|
||||
if out.SystemInfo.GPUCoreCount > 0 {
|
||||
name = fmt.Sprintf("%s (%d-core GPU)", name, out.SystemInfo.GPUCoreCount)
|
||||
}
|
||||
|
||||
// Unified memory has no dedicated VRAM sensor; use the memory temperature
|
||||
// group when mactop exposes it.
|
||||
var vramTempC int
|
||||
for _, t := range out.Temperatures {
|
||||
if strings.EqualFold(t.Group, "Memory") {
|
||||
vramTempC = int(math.Round(t.Avg))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Average fan load across all fans as a percentage of their RPM range.
|
||||
var fanSpeed float64
|
||||
var fanCount int
|
||||
for _, f := range out.Fans {
|
||||
if f.MaxRPM > f.MinRPM {
|
||||
pct := float64(f.RPM-f.MinRPM) / float64(f.MaxRPM-f.MinRPM) * 100
|
||||
if pct < 0 {
|
||||
pct = 0
|
||||
}
|
||||
fanSpeed += pct
|
||||
fanCount++
|
||||
}
|
||||
}
|
||||
if fanCount > 0 {
|
||||
fanSpeed /= float64(fanCount)
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: 0,
|
||||
Name: name,
|
||||
TempC: int(math.Round(out.SocMetrics.GPUTemp)),
|
||||
VramTempC: vramTempC,
|
||||
GpuUtilPct: out.GPUUsage,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
FanSpeedPct: fanSpeed,
|
||||
PowerDrawW: out.SocMetrics.GPUPower,
|
||||
}
|
||||
}
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("Not Implemented")
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -11,7 +15,156 @@ import (
|
||||
)
|
||||
|
||||
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
return nil, ErrNotImplemented
|
||||
if ch, err := tryMactop(ctx, every, logger); err == nil {
|
||||
logger.Info("using mactop for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("mactop: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := tryIoreg(ctx, every, logger); err == nil {
|
||||
logger.Info("using ioreg for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("ioreg: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// tryIoreg polls `ioreg -r -c IOGPU -d 1 -f` for Apple Silicon GPU stats. It is
|
||||
// a fallback for when mactop is not installed. ioreg exposes GPU utilization and
|
||||
// used memory but not power, temperature, or fan speed.
|
||||
func tryIoreg(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("ioreg"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// Verify ioreg actually reports a GPU device before committing to it, so we
|
||||
// can fall through to ErrNoGpuTool otherwise.
|
||||
if stat := sampleIoreg(ctx); stat == nil {
|
||||
return nil, fmt.Errorf("ioreg reported no GPU device")
|
||||
}
|
||||
|
||||
if every < time.Second {
|
||||
every = time.Second
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ticker := time.NewTicker(every)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
stat := sampleIoreg(ctx)
|
||||
if stat == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// sampleIoreg runs ioreg once and parses a single GpuStat, or returns nil.
|
||||
func sampleIoreg(ctx context.Context) *GpuStat {
|
||||
out, err := exec.CommandContext(ctx, "ioreg", "-r", "-c", "IOGPU", "-d", "1", "-f").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var memTotalMB int
|
||||
if vmStat, err := mem.VirtualMemory(); err == nil {
|
||||
memTotalMB = int(vmStat.Total / (1024 * 1024))
|
||||
}
|
||||
|
||||
return ParseIoregOutput(out, memTotalMB)
|
||||
}
|
||||
|
||||
// overlayIoregMem replaces a GpuStat's memory fields with the GPU-attributed
|
||||
// unified memory reported by ioreg. mactop only exposes whole-system memory, so
|
||||
// without this the mactop and ioreg backends would report different memory
|
||||
// semantics. It is a no-op when ioreg is unavailable or reports no GPU memory,
|
||||
// leaving the mactop-supplied values in place.
|
||||
func overlayIoregMem(ctx context.Context, stat *GpuStat) {
|
||||
ioStat := sampleIoreg(ctx)
|
||||
if ioStat == nil {
|
||||
return
|
||||
}
|
||||
stat.MemUsedMB = ioStat.MemUsedMB
|
||||
stat.MemTotalMB = ioStat.MemTotalMB
|
||||
stat.MemUtilPct = ioStat.MemUtilPct
|
||||
}
|
||||
|
||||
// tryMactop streams Apple Silicon GPU stats from mactop's headless mode.
|
||||
// See https://github.com/metaspartan/mactop. mactop emits one JSON object per
|
||||
// sample to stdout, which we parse into GpuStat.
|
||||
func tryMactop(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("mactop"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// mactop samples power over the interval, so give it at least a second.
|
||||
intervalMs := int(every.Milliseconds())
|
||||
if intervalMs < 1000 {
|
||||
intervalMs = 1000
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "mactop",
|
||||
"--headless",
|
||||
"--format", "json",
|
||||
"--interval", fmt.Sprintf("%d", intervalMs),
|
||||
)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mactop stdout pipe failed: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("mactop start failed: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
// mactop's JSON objects can be large; allow generous line lengths.
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stat := ParseMactopLine(line)
|
||||
if stat != nil {
|
||||
// mactop only reports whole-system memory; overlay ioreg's
|
||||
// GPU-attributed unified memory so both backends are consistent.
|
||||
overlayIoregMem(ctx, stat)
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd.Wait()
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func readSysStats() (SysStat, error) {
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -224,3 +224,90 @@ func TestCurrent_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestParseNvidiaSmiLine_ValidLine(t *testing.T) {
|
||||
line := "0, NVIDIA GeForce RTX 3080, GPU-12345678-1234-1234-1234-123456789abc, 65, 80, 8192, 10240, 75, 250"
|
||||
|
||||
stat := ParseNvidiaSmiLine(line)
|
||||
require.NotNil(t, stat)
|
||||
|
||||
assert.Equal(t, 0, stat.ID)
|
||||
assert.Equal(t, "NVIDIA GeForce RTX 3080", stat.Name)
|
||||
assert.Equal(t, "GPU-12345678-1234-1234-1234-123456789abc", stat.UUID)
|
||||
assert.Equal(t, 65, stat.TempC)
|
||||
assert.Equal(t, 80.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 8192, stat.MemUsedMB)
|
||||
assert.Equal(t, 10240, stat.MemTotalMB)
|
||||
assert.Equal(t, 75.0, stat.FanSpeedPct)
|
||||
assert.Equal(t, 250.0, stat.PowerDrawW)
|
||||
assert.InDelta(t, 80.0, stat.MemUtilPct, 0.01)
|
||||
}
|
||||
|
||||
func TestParseNvidiaSmiLine_ShortLine(t *testing.T) {
|
||||
line := "0, NVIDIA GPU, GPU-123"
|
||||
|
||||
stat := ParseNvidiaSmiLine(line)
|
||||
assert.Nil(t, stat)
|
||||
}
|
||||
|
||||
func TestParseNvidiaSmiLine_MissingFields(t *testing.T) {
|
||||
line := "0, NVIDIA GPU, GPU-123, 65, 80, 8192, 10240, 75"
|
||||
|
||||
stat := ParseNvidiaSmiLine(line)
|
||||
assert.Nil(t, stat)
|
||||
}
|
||||
|
||||
func TestParseNvidiaSmiLine_ZeroMemoryTotal(t *testing.T) {
|
||||
line := "0, NVIDIA GPU, GPU-123, 65, 80, 0, 0, 75, 250"
|
||||
|
||||
stat := ParseNvidiaSmiLine(line)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||
}
|
||||
|
||||
const ioregSample = `+-o AGXAcceleratorG13X <class AGXAcceleratorG13X, id 0x1000009a1, registered, matched, active, busy 0 (39191 ms), retain 108>
|
||||
{
|
||||
"model" = "Apple M1 Pro"
|
||||
"gpu-core-count" = 16
|
||||
"PerformanceStatistics" = {"In use system memory (driver)"=0,"Alloc system memory"=14511046656,"Tiler Utilization %"=34,"recoveryCount"=0,"Renderer Utilization %"=34,"Device Utilization %"=34,"In use system memory"=7688503296}
|
||||
"IOClass" = "AGXAcceleratorG13X"
|
||||
}`
|
||||
|
||||
func TestParseIoregOutput_ValidOutput(t *testing.T) {
|
||||
const memTotalMB = 32768
|
||||
|
||||
stat := ParseIoregOutput([]byte(ioregSample), memTotalMB)
|
||||
require.NotNil(t, stat)
|
||||
|
||||
assert.Equal(t, 0, stat.ID)
|
||||
assert.Equal(t, "Apple M1 Pro (16-core GPU)", stat.Name)
|
||||
assert.Equal(t, 34.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 7688503296/(1024*1024), stat.MemUsedMB)
|
||||
assert.Equal(t, memTotalMB, stat.MemTotalMB)
|
||||
assert.InDelta(t, float64(stat.MemUsedMB)/memTotalMB*100, stat.MemUtilPct, 0.01)
|
||||
// Not exposed by ioreg.
|
||||
assert.Equal(t, 0, stat.TempC)
|
||||
assert.Equal(t, 0.0, stat.PowerDrawW)
|
||||
assert.Equal(t, 0.0, stat.FanSpeedPct)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_NoGpuDevice(t *testing.T) {
|
||||
stat := ParseIoregOutput([]byte("no gpu here"), 32768)
|
||||
assert.Nil(t, stat)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_ZeroMemTotal(t *testing.T) {
|
||||
stat := ParseIoregOutput([]byte(ioregSample), 0)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_MissingModel(t *testing.T) {
|
||||
const out = `"Device Utilization %"=50,"In use system memory"=1048576`
|
||||
|
||||
stat := ParseIoregOutput([]byte(out), 1024)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, "Apple GPU", stat.Name)
|
||||
assert.Equal(t, 50.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 1, stat.MemUsedMB)
|
||||
}
|
||||
|
||||
+150
-30
@@ -38,6 +38,13 @@ func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monito
|
||||
logger.Debugf("nvidia-smi: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := tryRocmSmi(ctx, every, logger); err == nil {
|
||||
logger.Info("using rocm-smi for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("rocm-smi: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := trySysfs(ctx, every, logger); err == nil {
|
||||
logger.Info("using sysfs for GPU monitoring")
|
||||
return ch, nil
|
||||
@@ -139,7 +146,7 @@ func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monit
|
||||
cmd := exec.CommandContext(ctx, "nvidia-smi",
|
||||
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
|
||||
"--format=csv,noheader,nounits",
|
||||
"-loop", fmt.Sprintf("%d", sec),
|
||||
"--loop", fmt.Sprintf("%d", sec),
|
||||
)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
@@ -163,7 +170,7 @@ func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monit
|
||||
continue
|
||||
}
|
||||
|
||||
stat := parseNvidiaSmiLine(line)
|
||||
stat := ParseNvidiaSmiLine(line)
|
||||
if stat != nil {
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
@@ -177,40 +184,153 @@ func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monit
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func parseNvidiaSmiLine(line string) *GpuStat {
|
||||
fields := strings.Split(line, ", ")
|
||||
if len(fields) < 9 {
|
||||
func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("rocm-smi"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
if every < time.Second {
|
||||
every = time.Second
|
||||
}
|
||||
const pollTimeout = 5 * time.Second
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ticker := time.NewTicker(every)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
pollCtx, cancel := context.WithTimeout(ctx, pollTimeout)
|
||||
cmd := exec.CommandContext(pollCtx, "rocm-smi", "-i", "-P", "-t", "-f", "-u", "--showmemuse", "--showmeminfo", "vram", "--showproductname", "--csv")
|
||||
out, err := cmd.Output()
|
||||
timedOut := pollCtx.Err() == context.DeadlineExceeded
|
||||
cancel()
|
||||
if err != nil {
|
||||
if timedOut {
|
||||
logger.Debug("rocm-smi timed out")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
stats := make([]GpuStat, 0)
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(out)))
|
||||
var header string
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "device,") {
|
||||
header = line
|
||||
continue
|
||||
}
|
||||
|
||||
stat := parseRocmSmiLine(header, line)
|
||||
if stat != nil {
|
||||
stats = append(stats, *stat)
|
||||
}
|
||||
}
|
||||
|
||||
if len(stats) > 0 {
|
||||
select {
|
||||
case ch <- stats:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func parseRocmSmiLine(header string, line string) *GpuStat {
|
||||
if header == "" || line == "" {
|
||||
return nil
|
||||
}
|
||||
labels := strings.Split(header, ",")
|
||||
fields := strings.Split(line, ",")
|
||||
if len(labels) != len(fields) {
|
||||
return nil
|
||||
}
|
||||
|
||||
id, _ := strconv.Atoi(strings.TrimSpace(fields[0]))
|
||||
name := strings.TrimSpace(fields[1])
|
||||
uuid := strings.TrimSpace(fields[2])
|
||||
tempC, _ := strconv.Atoi(strings.TrimSpace(fields[3]))
|
||||
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[4]), 64)
|
||||
memUsed, _ := strconv.Atoi(strings.TrimSpace(fields[5]))
|
||||
memTotal, _ := strconv.Atoi(strings.TrimSpace(fields[6]))
|
||||
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[7]), 64)
|
||||
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
|
||||
|
||||
var memUtil float64
|
||||
if memTotal > 0 {
|
||||
memUtil = float64(memUsed) / float64(memTotal) * 100
|
||||
result := &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: -1,
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: id,
|
||||
Name: name,
|
||||
UUID: uuid,
|
||||
TempC: tempC,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsed,
|
||||
MemTotalMB: memTotal,
|
||||
FanSpeedPct: fanSpeed,
|
||||
PowerDrawW: powerDraw,
|
||||
var device string
|
||||
var deviceName string
|
||||
var cardSeries string
|
||||
var gfxVersion string
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
|
||||
for i, col := range labels {
|
||||
val := strings.TrimSpace(fields[i])
|
||||
switch col {
|
||||
case "device":
|
||||
device = val
|
||||
id, err := strconv.Atoi(strings.TrimPrefix(val, "card"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result.ID = id
|
||||
case "Device Name":
|
||||
deviceName = val
|
||||
case "GUID":
|
||||
result.UUID = val
|
||||
case "Temperature (Sensor edge) (C)":
|
||||
tempC, _ := strconv.ParseFloat(val, 64)
|
||||
result.TempC = int(tempC)
|
||||
case "Temperature (Sensor memory) (C)":
|
||||
vramTempC, _ := strconv.ParseFloat(val, 64)
|
||||
result.VramTempC = int(vramTempC)
|
||||
case "Fan speed (%)":
|
||||
fanSpeed, _ := strconv.ParseFloat(val, 64)
|
||||
result.FanSpeedPct = fanSpeed
|
||||
case "Current Socket Graphics Package Power (W)":
|
||||
fallthrough
|
||||
case "Average Graphics Package Power (W)":
|
||||
powerDraw, _ := strconv.ParseFloat(val, 64)
|
||||
result.PowerDrawW = powerDraw
|
||||
case "GPU use (%)":
|
||||
gpuUtil, _ := strconv.ParseFloat(val, 64)
|
||||
result.GpuUtilPct = gpuUtil
|
||||
case "GPU Memory Allocated (VRAM%)":
|
||||
memUtil, _ := strconv.ParseFloat(val, 64)
|
||||
result.MemUtilPct = memUtil
|
||||
case "VRAM Total Memory (B)":
|
||||
memTotal, _ := strconv.ParseUint(val, 10, 64)
|
||||
result.MemTotalMB = int(memTotal / toMB)
|
||||
case "VRAM Total Used Memory (B)":
|
||||
memUsed, _ := strconv.ParseUint(val, 10, 64)
|
||||
result.MemUsedMB = int(memUsed / toMB)
|
||||
case "Card Series":
|
||||
cardSeries = val
|
||||
case "GFX Version":
|
||||
gfxVersion = val
|
||||
}
|
||||
}
|
||||
|
||||
if result.ID == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := device
|
||||
if cardSeries != "" && cardSeries != "N/A" {
|
||||
name = cardSeries + " " + device + " (" + gfxVersion + ")"
|
||||
} else if deviceName != "" && deviceName != "N/A" {
|
||||
name = deviceName + " " + device + " (" + gfxVersion + ")"
|
||||
}
|
||||
result.Name = name
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func trySysfs(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -11,7 +15,75 @@ import (
|
||||
)
|
||||
|
||||
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
return nil, ErrNotImplemented
|
||||
if ch, err := tryNvidiaSmiWindows(ctx, every, logger); err == nil {
|
||||
logger.Info("using nvidia-smi for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("nvidia-smi: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := tryD3DKMT(ctx, every, logger); err == nil {
|
||||
logger.Info("using D3DKMT for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("D3DKMT: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// tryNvidiaSmiWindows starts nvidia-smi in loop mode on Windows and returns
|
||||
// a channel receiving GPU stat snapshots. Returns ErrNoGpuTool if nvidia-smi
|
||||
// is not available.
|
||||
func tryNvidiaSmiWindows(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("nvidia-smi"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
sec := int(every.Seconds())
|
||||
if sec < 1 {
|
||||
sec = 1
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "nvidia-smi",
|
||||
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
|
||||
"--format=csv,noheader,nounits",
|
||||
"--loop", fmt.Sprintf("%d", sec),
|
||||
)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nvidia-smi stdout pipe failed: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("nvidia-smi start failed: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stat := ParseNvidiaSmiLine(line)
|
||||
if stat != nil {
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd.Wait()
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func readSysStats() (SysStat, error) {
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
pdhDLL = windows.NewLazySystemDLL("pdh.dll")
|
||||
procPdhOpenQuery = pdhDLL.NewProc("PdhOpenQueryW")
|
||||
procPdhAddEnglishCounter = pdhDLL.NewProc("PdhAddEnglishCounterW")
|
||||
procPdhCollectQueryData = pdhDLL.NewProc("PdhCollectQueryData")
|
||||
procPdhGetFormattedCounterArray = pdhDLL.NewProc("PdhGetFormattedCounterArrayW")
|
||||
procPdhCloseQuery = pdhDLL.NewProc("PdhCloseQuery")
|
||||
)
|
||||
|
||||
const (
|
||||
pdhFmtDouble = 0x00000200
|
||||
pdhMoreData = 0x800007D2
|
||||
pdhNoData = 0x800007D5
|
||||
)
|
||||
|
||||
type pdhCounterValue struct {
|
||||
CStatus uint32
|
||||
DblVal float64
|
||||
}
|
||||
|
||||
type pdhCounterValueItem struct {
|
||||
SzName *uint16
|
||||
FmtValue pdhCounterValue
|
||||
}
|
||||
|
||||
func init() {
|
||||
var item pdhCounterValueItem
|
||||
if unsafe.Sizeof(item) != 24 {
|
||||
panic(fmt.Sprintf("pdhCounterValueItem size %d != expected 24 on x64", unsafe.Sizeof(item)))
|
||||
}
|
||||
}
|
||||
|
||||
type pdhGpuUtil struct {
|
||||
query uintptr
|
||||
counter uintptr
|
||||
}
|
||||
|
||||
// initPdhGpuUtil creates a PDH query for the GPU Engine utilization counter.
|
||||
// Returns nil with an error if PDH or the counter is unavailable.
|
||||
func initPdhGpuUtil() (*pdhGpuUtil, error) {
|
||||
var query uintptr
|
||||
if ret, _, _ := procPdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query))); ret != 0 {
|
||||
return nil, fmt.Errorf("PdhOpenQuery: 0x%x", ret)
|
||||
}
|
||||
|
||||
path, _ := windows.UTF16PtrFromString(`\GPU Engine(*)\Utilization Percentage`)
|
||||
var counter uintptr
|
||||
if ret, _, _ := procPdhAddEnglishCounter.Call(
|
||||
query, uintptr(unsafe.Pointer(path)), 0, uintptr(unsafe.Pointer(&counter)),
|
||||
); ret != 0 {
|
||||
procPdhCloseQuery.Call(query)
|
||||
return nil, fmt.Errorf("PdhAddEnglishCounter(GPU Engine): 0x%x", ret)
|
||||
}
|
||||
|
||||
procPdhCollectQueryData.Call(query)
|
||||
|
||||
return &pdhGpuUtil{query: query, counter: counter}, nil
|
||||
}
|
||||
|
||||
// close releases the PDH query handle.
|
||||
func (p *pdhGpuUtil) close() {
|
||||
if p.query != 0 {
|
||||
procPdhCloseQuery.Call(p.query)
|
||||
p.query = 0
|
||||
}
|
||||
}
|
||||
|
||||
// collect reads the PDH counter and returns a map of adapter LUID to
|
||||
// aggregated GPU utilization percentage, summed across all engine instances
|
||||
// per adapter and clamped to 100%.
|
||||
func (p *pdhGpuUtil) collect() map[LUID]float64 {
|
||||
ret, _, _ := procPdhCollectQueryData.Call(p.query)
|
||||
if ret != 0 && ret != pdhNoData {
|
||||
return nil
|
||||
}
|
||||
|
||||
var bufSize uint32
|
||||
var itemCount uint32
|
||||
ret, _, _ = procPdhGetFormattedCounterArray.Call(
|
||||
p.counter, pdhFmtDouble,
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
uintptr(unsafe.Pointer(&itemCount)),
|
||||
0,
|
||||
)
|
||||
if ret != pdhMoreData || itemCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := make([]byte, bufSize)
|
||||
ret, _, _ = procPdhGetFormattedCounterArray.Call(
|
||||
p.counter, pdhFmtDouble,
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
uintptr(unsafe.Pointer(&itemCount)),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
)
|
||||
if ret != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
itemSize := uint32(unsafe.Sizeof(pdhCounterValueItem{}))
|
||||
result := make(map[LUID]float64)
|
||||
|
||||
for i := uint32(0); i < itemCount; i++ {
|
||||
item := (*pdhCounterValueItem)(unsafe.Pointer(&buf[i*itemSize]))
|
||||
if item.FmtValue.CStatus != 0 {
|
||||
continue
|
||||
}
|
||||
luid, ok := parsePdhLuid(windows.UTF16PtrToString(item.SzName))
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
result[luid] += item.FmtValue.DblVal
|
||||
}
|
||||
|
||||
for luid := range result {
|
||||
if result[luid] > 100.0 {
|
||||
result[luid] = 100.0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parsePdhLuid extracts the adapter LUID (high and low parts) from a PDH
|
||||
// GPU Engine instance name (e.g. "pid_1234_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute").
|
||||
func parsePdhLuid(name string) (LUID, bool) {
|
||||
idx := strings.Index(name, "luid_0x")
|
||||
if idx < 0 {
|
||||
return LUID{}, false
|
||||
}
|
||||
rest := name[idx+7:]
|
||||
parts := strings.SplitN(rest, "_", 4)
|
||||
if len(parts) < 3 {
|
||||
return LUID{}, false
|
||||
}
|
||||
hp, err := strconv.ParseUint(parts[0], 16, 32)
|
||||
if err != nil {
|
||||
return LUID{}, false
|
||||
}
|
||||
lpStr := strings.TrimPrefix(parts[1], "0x")
|
||||
lp, err := strconv.ParseUint(lpStr, 16, 32)
|
||||
if err != nil {
|
||||
return LUID{}, false
|
||||
}
|
||||
return LUID{LowPart: uint32(lp), HighPart: int32(hp)}, true
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParsePdhLuid_Valid(t *testing.T) {
|
||||
name := `pid_25312_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute`
|
||||
got, ok := parsePdhLuid(name)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint32(0x000148BF), got.LowPart)
|
||||
assert.Equal(t, int32(0x00000000), got.HighPart)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_ValidNvidia(t *testing.T) {
|
||||
name := `pid_1388_luid_0x00000000_0x00011372_phys_0_eng_8_engtype_Compute_1`
|
||||
got, ok := parsePdhLuid(name)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint32(0x00011372), got.LowPart)
|
||||
assert.Equal(t, int32(0x00000000), got.HighPart)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_NonZeroHighPart(t *testing.T) {
|
||||
name := `pid_1234_luid_0x00000001_0x0000C85A_phys_0_eng_5_engtype_Copy`
|
||||
got, ok := parsePdhLuid(name)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint32(0x0000C85A), got.LowPart)
|
||||
assert.Equal(t, int32(0x00000001), got.HighPart)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_InvalidNoLuid(t *testing.T) {
|
||||
_, ok := parsePdhLuid("invalid_string_without_luid")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_InvalidEmpty(t *testing.T) {
|
||||
_, ok := parsePdhLuid("")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_InvalidHex(t *testing.T) {
|
||||
_, ok := parsePdhLuid("pid_1234_luid_0xZZZZ_0xGGGG_phys_0")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_ShortAfterLuid(t *testing.T) {
|
||||
_, ok := parsePdhLuid("pid_1234_luid_0x00000000")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var simpleResponderPath string
|
||||
|
||||
func skipIfNoSimpleResponder(t *testing.T) {
|
||||
t.Helper()
|
||||
if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) {
|
||||
t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
if goos == "windows" {
|
||||
simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("getFreePort: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func simpleResponderCmd(t *testing.T, args ...string) (string, int) {
|
||||
port := getFreePort(t)
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
base := []string{cmdPath, fmt.Sprintf("-port %d", port)}
|
||||
base = append(base, args...)
|
||||
return strings.Join(base, " "), port
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type Process interface {
|
||||
// Run starts the process blocks until the process is terminated.
|
||||
// The timeout parameter controls how long to wait for the process to get
|
||||
// to a ready state to process traffic
|
||||
Run(timeout time.Duration) error
|
||||
|
||||
// WaitReady blocks until the process is ready to serve requests
|
||||
// or the context is cancelled. It returns nil when the process is ready
|
||||
WaitReady(context.Context) error
|
||||
|
||||
// Stop blocks until the process has terminated. It returns nil when
|
||||
// the process terminated as expected (exit 0)
|
||||
Stop(timeout time.Duration) error
|
||||
|
||||
// State returns the current state of the process
|
||||
// Note: this is a snapshot of the state at the time of the call
|
||||
// and may change at any time after the call returns.
|
||||
State() ProcessState
|
||||
|
||||
// ServeHTTP forwards requests to the underlying process
|
||||
// Calling it when the process is not ready will result in a
|
||||
// 503 response with a body indicating it is a llama-swap-error
|
||||
ServeHTTP(http.ResponseWriter, *http.Request)
|
||||
|
||||
// Logger returns the monitor that captures this process's stdout/stderr.
|
||||
Logger() *logmon.Monitor
|
||||
}
|
||||
@@ -0,0 +1,684 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var ErrStartAborted = fmt.Errorf("aborted")
|
||||
|
||||
// cmdWaitDelay is the upper bound the runtime will wait for child I/O to
|
||||
// drain after the process exits before force-closing the stdout/stderr
|
||||
// pipes. Required so that cmd.Wait() returns even when a forked grandchild
|
||||
// inherits and holds the pipes open (e.g. a shell wrapper that backgrounds
|
||||
// the real binary). killProcess sends the stop signal directly (not via the
|
||||
// cmd context), so this delay is measured from process exit rather than from
|
||||
// the stop request, and stays independent of the caller's graceful timeout.
|
||||
const cmdWaitDelay = 10 * time.Second
|
||||
|
||||
// parentCancelGraceTimeout is the graceful timeout used when the process is
|
||||
// torn down because parentCtx was cancelled (final router teardown or app
|
||||
// shutdown). In the normal flow the process has already been stopped via
|
||||
// Stop() by this point, so killProcess is a no-op kill; the short grace just
|
||||
// bounds the rare case where a process is still alive when its context is cut.
|
||||
const parentCancelGraceTimeout = time.Second
|
||||
|
||||
type runReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type stopReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type waitReadyReq struct {
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type startResult struct {
|
||||
cmd *exec.Cmd
|
||||
cmdDone chan struct{}
|
||||
cancel context.CancelFunc
|
||||
handlerFn http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type ProcessCommand struct {
|
||||
id string
|
||||
config config.ModelConfig
|
||||
parentCtx context.Context
|
||||
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
// waitDelay is assigned to cmd.WaitDelay when starting the upstream
|
||||
// process. Defaults to cmdWaitDelay; tests override it to keep the
|
||||
// pipe-close backstop from dominating their runtime.
|
||||
waitDelay time.Duration
|
||||
|
||||
runCh chan runReq
|
||||
stopCh chan stopReq
|
||||
waitReadyCh chan waitReadyReq
|
||||
|
||||
// current ProcessState. Written only by run(); read by State() via atomic load.
|
||||
state atomic.Value
|
||||
|
||||
// stores the active reverse-proxy handler when the process is running.
|
||||
// Written only by run(); read by ServeHTTP via atomic load.
|
||||
handler atomic.Pointer[http.HandlerFunc]
|
||||
|
||||
lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion
|
||||
inflight atomic.Int64 // current in-flight ServeHTTP calls
|
||||
}
|
||||
|
||||
var _ Process = (*ProcessCommand)(nil)
|
||||
|
||||
func New(
|
||||
parentCtx context.Context,
|
||||
id string,
|
||||
conf config.ModelConfig,
|
||||
processLogger *logmon.Monitor,
|
||||
proxyLogger *logmon.Monitor,
|
||||
) (*ProcessCommand, error) {
|
||||
p := &ProcessCommand{
|
||||
id: id,
|
||||
config: conf,
|
||||
parentCtx: parentCtx,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
|
||||
runCh: make(chan runReq),
|
||||
stopCh: make(chan stopReq),
|
||||
waitReadyCh: make(chan waitReadyReq),
|
||||
waitDelay: cmdWaitDelay,
|
||||
}
|
||||
p.state.Store(StateStopped)
|
||||
|
||||
go p.run()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Logger() *logmon.Monitor { return p.processLogger }
|
||||
|
||||
// run is the single-writer goroutine that owns all mutable lifecycle state
|
||||
// (current ProcessState, the running *exec.Cmd, the active reverse-proxy
|
||||
// handler, and the list of WaitReady subscribers). Every public method
|
||||
// (Run / Stop / State / WaitReady) is a thin client that sends a request on
|
||||
// one of the channels below and waits for a response — this funnels concurrent
|
||||
// callers through a single serialization point so the state machine never
|
||||
// observes a race.
|
||||
func (p *ProcessCommand) run() {
|
||||
// Mutable state — only read/written from this goroutine. ServeHTTP reads
|
||||
// p.handler concurrently, which is why handler is an atomic.Pointer.
|
||||
// p.state mirrors `state` so State() can observe transitions; setState
|
||||
// writes both.
|
||||
state := StateStopped
|
||||
setState := func(s ProcessState) {
|
||||
old := state
|
||||
state = s
|
||||
p.state.Store(s)
|
||||
if old != s {
|
||||
event.Emit(shared.ProcessStateChangeEvent{
|
||||
ProcessName: p.id,
|
||||
OldState: string(old),
|
||||
NewState: string(s),
|
||||
})
|
||||
}
|
||||
}
|
||||
var (
|
||||
cmd *exec.Cmd
|
||||
cmdDone <-chan struct{}
|
||||
cmdCancel context.CancelFunc
|
||||
readyWaiters []waitReadyReq
|
||||
// runResp parks the in-flight Run caller's response channel. The
|
||||
// interface contract is that Run blocks until the process is
|
||||
// terminated, so we hold this until Stop, parentCtx, or an
|
||||
// upstream exit unblocks it via respondRun.
|
||||
runResp chan<- error
|
||||
)
|
||||
|
||||
// notifyWaiters wakes every blocked WaitReady caller with the given result.
|
||||
// Used on transitions out of StateStarting (ready, failed, aborted, or
|
||||
// shutdown) — anything that resolves the "is it ready yet?" question.
|
||||
notifyWaiters := func(err error) {
|
||||
for _, w := range readyWaiters {
|
||||
select {
|
||||
case w.respond <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
readyWaiters = nil
|
||||
}
|
||||
|
||||
// respondRun delivers the final Run result, if a Run caller is parked.
|
||||
respondRun := func(err error) {
|
||||
if runResp != nil {
|
||||
runResp <- err
|
||||
runResp = nil
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
// Shutdown: parent context cancelled. Tear down any running process,
|
||||
// wake any pending WaitReady callers with an error, then exit the
|
||||
// goroutine permanently. Subsequent public-method calls will fail
|
||||
// because parentCtx.Done() unblocks their send-side selects.
|
||||
case <-p.parentCtx.Done():
|
||||
// Mark shutdown before killProcess so concurrent State() readers
|
||||
// stop treating this process as ready while the (possibly slow)
|
||||
// teardown is in progress.
|
||||
setState(StateShutdown)
|
||||
if cmd != nil {
|
||||
p.handler.Store(nil)
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, parentCancelGraceTimeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
return
|
||||
|
||||
// Upstream exited on its own (not via Stop). Drop handler state,
|
||||
// transition to Stopped, and unblock the parked Run caller.
|
||||
// cmdDone is nil while no process is running, so this case is
|
||||
// dormant outside of StateReady.
|
||||
case <-cmdDone:
|
||||
if cmdCancel != nil {
|
||||
cmdCancel()
|
||||
}
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
p.handler.Store(nil)
|
||||
setState(StateStopped)
|
||||
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
|
||||
|
||||
// WaitReady: if we're already in a terminal-for-this-question state,
|
||||
// respond immediately; otherwise queue the caller and let a future
|
||||
// state transition wake them via notifyWaiters.
|
||||
case req := <-p.waitReadyCh:
|
||||
switch state {
|
||||
case StateReady:
|
||||
req.respond <- nil
|
||||
case StateShutdown:
|
||||
req.respond <- fmt.Errorf("[%s] shutdown", p.id)
|
||||
default:
|
||||
readyWaiters = append(readyWaiters, req)
|
||||
}
|
||||
|
||||
// Run: start the upstream process. Only valid from StateStopped.
|
||||
// doStart can take a long time (health-check polling), so it runs in
|
||||
// a separate goroutine and we wait on resultCh. While waiting we also
|
||||
// listen for an incoming Stop — that's how callers cancel an in-flight
|
||||
// start.
|
||||
case req := <-p.runCh:
|
||||
if state != StateStopped {
|
||||
req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state)
|
||||
continue
|
||||
}
|
||||
setState(StateStarting)
|
||||
|
||||
startCtx, cancelStart := context.WithCancel(context.Background())
|
||||
resultCh := make(chan startResult, 1)
|
||||
go func() {
|
||||
resultCh <- p.doStart(startCtx, req.timeout)
|
||||
}()
|
||||
|
||||
// pendingStop holds a Stop request that arrived mid-start, so we
|
||||
// can respond to it AFTER we've finished tearing the start down.
|
||||
var pendingStop *stopReq
|
||||
select {
|
||||
// doStart finished on its own — either successfully (latch
|
||||
// cmd/handler and move to Ready) or with an error (back to
|
||||
// Stopped). Either way wake WaitReady subscribers and reply
|
||||
// to the Run caller.
|
||||
case res := <-resultCh:
|
||||
if res.err == nil {
|
||||
cmd = res.cmd
|
||||
cmdDone = res.cmdDone
|
||||
cmdCancel = res.cancel
|
||||
fn := res.handlerFn
|
||||
p.handler.Store(&fn)
|
||||
setState(StateReady)
|
||||
notifyWaiters(nil)
|
||||
// Park the Run response — Run blocks until the process
|
||||
// terminates, so we only fire this when Stop, parentCtx,
|
||||
// or the upstream exit takes the process down.
|
||||
runResp = req.respond
|
||||
|
||||
// Start TTL goroutine if configured — self-terminates
|
||||
// when state leaves StateReady.
|
||||
if p.config.UnloadAfter > 0 {
|
||||
ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if p.State() != StateReady {
|
||||
return
|
||||
}
|
||||
if p.inflight.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter)
|
||||
p.Stop(10 * time.Second)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
setState(StateStopped)
|
||||
notifyWaiters(res.err)
|
||||
req.respond <- res.err
|
||||
}
|
||||
|
||||
// Stop arrived while doStart was still running. Cancel the
|
||||
// start context to abort it, then wait for doStart to return.
|
||||
// If doStart had already crossed the finish line before
|
||||
// cancellation took effect, it returns a live cmd that we
|
||||
// must kill ourselves. The Run caller gets ErrAbort; the Stop
|
||||
// caller is parked in pendingStop and answered below.
|
||||
case stop := <-p.stopCh:
|
||||
cancelStart()
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cancel, res.cmdDone, stop.timeout)
|
||||
}
|
||||
setState(StateStopped)
|
||||
notifyWaiters(ErrStartAborted)
|
||||
req.respond <- ErrStartAborted
|
||||
pendingStop = &stop
|
||||
|
||||
// Parent context cancelled (e.g. config reload) while doStart
|
||||
// was still running. Stop() returns early when parentCtx is
|
||||
// done and never sends on stopCh, so we must handle shutdown
|
||||
// here to avoid leaving doStart running indefinitely.
|
||||
case <-p.parentCtx.Done():
|
||||
cancelStart()
|
||||
// Mark shutdown before tearing the process down: killProcess
|
||||
// may block (e.g. taskkill on Windows is slow to spawn), and
|
||||
// callers observing State() should see StateShutdown promptly
|
||||
// rather than a stale StateStarting.
|
||||
setState(StateShutdown)
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cancel, res.cmdDone, parentCancelGraceTimeout)
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
return
|
||||
}
|
||||
// cancelStart is idempotent; calling it again here ensures the
|
||||
// context is released even on the success path (govet leak check).
|
||||
cancelStart()
|
||||
if pendingStop != nil {
|
||||
pendingStop.respond <- nil
|
||||
}
|
||||
|
||||
// Stop: tear down a running process.
|
||||
case stop := <-p.stopCh:
|
||||
if cmd != nil {
|
||||
setState(StateStopping)
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, stop.timeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
p.handler.Store(nil)
|
||||
}
|
||||
// Stop is a no-op (and not an error) when already Stopped — this
|
||||
// is what makes it idempotent for callers that don't track state.
|
||||
setState(StateStopped)
|
||||
respondRun(nil)
|
||||
stop.respond <- nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult {
|
||||
if p.config.Proxy == "" {
|
||||
return startResult{err: fmt.Errorf("upstream proxy missing")}
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)}
|
||||
}
|
||||
|
||||
proxyURL, err := url.Parse(p.config.Proxy)
|
||||
if err != nil {
|
||||
return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)}
|
||||
}
|
||||
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
reverseProxy.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream
|
||||
// disconnects after response headers have been sent. Recover here so the
|
||||
// streaming termination is treated as a normal client/upstream disconnect.
|
||||
// see: https://github.com/golang/go/issues/23643
|
||||
handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if rec == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec)
|
||||
}
|
||||
}
|
||||
}()
|
||||
reverseProxy.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// cmdCtx + cmd.Cancel are wired as a safety net: if the context is ever
|
||||
// cancelled while the process is alive, cmd.Cancel sends SIGTERM / CmdStop
|
||||
// and the runtime escalates to SIGKILL after cmd.WaitDelay. In the normal
|
||||
// teardown path killProcess sends the stop signal directly instead, so
|
||||
// cmd.WaitDelay only acts as the inherited-pipe backstop measured from
|
||||
// process exit (see killProcess).
|
||||
cmdCtx, cmdCancel := context.WithCancel(context.Background())
|
||||
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
|
||||
cmd.Stderr = p.processLogger
|
||||
cmd.Stdout = p.processLogger
|
||||
cmd.Env = append(cmd.Environ(), p.config.Env...)
|
||||
cmd.Cancel = func() error { return p.sendStopSignal(cmd) }
|
||||
cmd.WaitDelay = p.waitDelay
|
||||
setProcAttributes(cmd)
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
if err := cmd.Start(); err != nil {
|
||||
cmdCancel()
|
||||
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
|
||||
}
|
||||
|
||||
go func() {
|
||||
waitErr := cmd.Wait()
|
||||
switch st := p.State(); {
|
||||
case waitErr == nil:
|
||||
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
|
||||
case st == StateStopping || st == StateShutdown:
|
||||
// Expected: we force-terminated the process. A forced kill exits
|
||||
// the child with a non-zero code (e.g. taskkill /f on Windows
|
||||
// yields exit status 1), so this is not an error.
|
||||
p.proxyLogger.Debugf("<%s> process stopped by llama-swap: %v", p.id, waitErr)
|
||||
default:
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
}
|
||||
}
|
||||
close(cmdDone)
|
||||
}()
|
||||
|
||||
abort := func(err error) startResult {
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, 5*time.Second)
|
||||
return startResult{err: err}
|
||||
}
|
||||
prematureExit := func() startResult {
|
||||
cmdCancel()
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
}
|
||||
|
||||
if startCtx.Err() != nil {
|
||||
return abort(ErrStartAborted)
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
if checkEndpoint == "none" {
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
// Wait 250ms for the command to start up before health checking
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
return abort(ErrStartAborted)
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(healthCheckTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
return abort(ErrStartAborted)
|
||||
case <-cmdDone:
|
||||
return prematureExit()
|
||||
default:
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
return abort(fmt.Errorf("health check timed out after %v", healthCheckTimeout))
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
reverseProxy.ServeHTTP(rr, req)
|
||||
resp := rr.Result()
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
|
||||
break
|
||||
} else if startCtx.Err() != nil {
|
||||
return abort(ErrStartAborted)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
return abort(ErrStartAborted)
|
||||
case <-cmdDone:
|
||||
return prematureExit()
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
// sendStopSignal runs the configured CmdStop (if any) or sends SIGTERM to
|
||||
// the upstream process. Wired up as cmd.Cancel so it fires whenever the
|
||||
// cmd's context is cancelled.
|
||||
func (p *ProcessCommand) sendStopSignal(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() called with nil cmd or process, nothing to stop", p.id)
|
||||
return nil
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
if p.config.CmdStop != "" {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() using CmdStop %q for pid %d", p.id, p.config.CmdStop, pid)
|
||||
stopArgs, err := config.SanitizeCommand(
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", pid)),
|
||||
)
|
||||
if err == nil {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() running stop command: %s", p.id, strings.Join(stopArgs, " "))
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Env = cmd.Env
|
||||
setProcAttributes(stopCmd)
|
||||
runErr := stopCmd.Run()
|
||||
if runErr != nil {
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() stop command failed: %v", p.id, runErr)
|
||||
} else {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() stop command completed for pid %d", p.id, pid)
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
// fall through to SIGTERM if sanitize failed
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() failed to sanitize CmdStop %q: %v, falling back to terminateProcessTree", p.id, p.config.CmdStop, err)
|
||||
}
|
||||
// On Unix this SIGTERMs the whole process group so a forked grandchild
|
||||
// (e.g. a shell wrapper that backgrounds the real binary) is taken down
|
||||
// with the parent rather than orphaned.
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() no CmdStop configured, calling terminateProcessTree for pid %d", p.id, pid)
|
||||
termErr := terminateProcessTree(cmd)
|
||||
if termErr != nil {
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() terminateProcessTree failed for pid %d: %v", p.id, pid, termErr)
|
||||
}
|
||||
return termErr
|
||||
}
|
||||
|
||||
// killProcess terminates the upstream process. The flow:
|
||||
//
|
||||
// 1. Send the graceful stop signal (CmdStop / SIGTERM) directly — NOT by
|
||||
// cancelling cmdCtx. Cancelling the context would start cmd.WaitDelay
|
||||
// immediately, which force-kills the process WaitDelay after the signal
|
||||
// and would silently cap gracefulTimeout at WaitDelay whenever
|
||||
// gracefulTimeout is the longer of the two.
|
||||
// 2. We wait up to gracefulTimeout for the process to exit on its own.
|
||||
// 3. If still alive, we SIGKILL the process group directly (Unix) so any
|
||||
// forked descendant is force-terminated alongside the parent.
|
||||
// 4. We wait on cmdDone. cmd.WaitDelay (set when the cmd was built) is the
|
||||
// critical backstop here: once the process exits, if a forked grandchild
|
||||
// inherited the stdout/stderr pipes and is still holding them, the runtime
|
||||
// force-closes the pipes WaitDelay after the exit and cmd.Wait() unblocks.
|
||||
// Because we never cancelled the context, that WaitDelay timer measures
|
||||
// from process exit (see os/exec awaitGoroutines), not from this call.
|
||||
// Without WaitDelay this select would hang forever (the v219 bug).
|
||||
//
|
||||
// cancel() is still invoked (deferred) to release the context, but only after
|
||||
// the process has exited and os/exec's ctx watcher has already torn down, so it
|
||||
// never re-fires cmd.Cancel.
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cancel context.CancelFunc, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
// Deliver CmdStop / SIGTERM in a goroutine so a slow or hanging CmdStop
|
||||
// cannot block the run() goroutine; the gracefulTimeout + Process.Kill
|
||||
// path below still guarantees teardown.
|
||||
if cmd != nil {
|
||||
go func() {
|
||||
p.proxyLogger.Debugf("[%s] sending stop signal with timeout %v", p.id, gracefulTimeout)
|
||||
if err := p.sendStopSignal(cmd); err != nil {
|
||||
p.proxyLogger.Warnf("[%s] stop signal failed: %v", p.id, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(gracefulTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
// SIGKILL the whole process group on Unix so any descendant that
|
||||
// ignored or outlived the graceful signal is force-terminated too.
|
||||
_ = killProcessTree(cmd)
|
||||
}
|
||||
<-cmdDone
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ID() string {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Run(timeout time.Duration) error {
|
||||
req := runReq{
|
||||
timeout: timeout,
|
||||
respond: make(chan error, 1),
|
||||
}
|
||||
select {
|
||||
case p.runCh <- req:
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
select {
|
||||
case err := <-req.respond:
|
||||
return err
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) WaitReady(ctx context.Context) error {
|
||||
req := waitReadyReq{respond: make(chan error, 1)}
|
||||
select {
|
||||
case p.waitReadyCh <- req:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
select {
|
||||
case err := <-req.respond:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Stop(timeout time.Duration) error {
|
||||
req := stopReq{
|
||||
timeout: timeout,
|
||||
respond: make(chan error, 1),
|
||||
}
|
||||
select {
|
||||
case p.stopCh <- req:
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
return <-req.respond
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) State() ProcessState {
|
||||
if s, ok := p.state.Load().(ProcessState); ok {
|
||||
return s
|
||||
}
|
||||
return StateStopped
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fn := p.handler.Load()
|
||||
if fn == nil {
|
||||
http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
p.inflight.Add(1)
|
||||
defer func() {
|
||||
p.lastUse.Store(time.Now().UnixNano())
|
||||
p.inflight.Add(-1)
|
||||
}()
|
||||
(*fn)(w, r)
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
// TestProcessCommand_StopForkingWrapper is a regression for the bug reported
|
||||
// against v219 where Stop would hang indefinitely when the upstream command
|
||||
// is a shell wrapper that forks the real binary (e.g. `#!/bin/bash` then
|
||||
// `"$@"`). After SIGTERM the wrapper dies but the grandchild inherits the
|
||||
// stdout/stderr pipes; cmd.Wait() blocks waiting for the pipe-copy goroutine
|
||||
// to drain EOF, which never happens while the grandchild holds the fds.
|
||||
//
|
||||
// The fix is cmd.WaitDelay (combined with exec.CommandContext + cmd.Cancel),
|
||||
// which causes the runtime to force-close the pipes after the delay so
|
||||
// cmd.Wait() — and therefore Stop — returns.
|
||||
func TestProcessCommand_StopForkingWrapper(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
port := getFreePort(t)
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "child.pid")
|
||||
|
||||
// Wrapper script: backgrounds the child (which inherits stdout/stderr),
|
||||
// records its PID for cleanup, then waits. When SIGTERM hits bash it
|
||||
// dies without forwarding the signal; the grandchild keeps running and
|
||||
// keeps the inherited pipe fds open. This is the scenario reported in
|
||||
// the v219 regression.
|
||||
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||
simpleResponderPath, port, pidFile)
|
||||
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: wrapper,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
// Shrink the pipe-close backstop so the test doesn't sit at the
|
||||
// production default (10s). Must be set before Run() so doStart picks
|
||||
// it up when building the cmd.
|
||||
const testWaitDelay = 250 * time.Millisecond
|
||||
p.waitDelay = testWaitDelay
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Stop must return within a bounded time even though the grandchild
|
||||
// is still holding the pipe open. Budget is generous on top of
|
||||
// testWaitDelay to absorb scheduling jitter on slow CI runners; the
|
||||
// pre-fix behaviour was an unbounded hang, so any reasonable cap
|
||||
// distinguishes pass from fail.
|
||||
stopReturned := make(chan error, 1)
|
||||
stopStart := time.Now()
|
||||
go func() { stopReturned <- p.Stop(testStopTimeout) }()
|
||||
|
||||
const stopBudget = testWaitDelay + 2*time.Second
|
||||
select {
|
||||
case err := <-stopReturned:
|
||||
if err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
t.Logf("Stop returned in %v", time.Since(stopStart))
|
||||
case <-time.After(stopBudget):
|
||||
t.Fatalf("Stop did not return within %v — cmd.Wait() likely hung on inherited pipe", stopBudget)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopHonorsGracefulTimeout is a regression for the bug
|
||||
// where cmd.WaitDelay capped the graceful shutdown window. killProcess used to
|
||||
// cancel the cmd context to deliver SIGTERM, which starts cmd.WaitDelay
|
||||
// immediately; a process whose SIGTERM handler needs longer than WaitDelay to
|
||||
// finish was force-killed early even though Stop was given a much longer
|
||||
// timeout. The fix sends the signal directly so WaitDelay measures from process
|
||||
// exit (its inherited-pipe backstop role), leaving the graceful window to the
|
||||
// caller's Stop timeout.
|
||||
func TestProcessCommand_StopHonorsGracefulTimeout(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
marker := filepath.Join(dir, "graceful.done")
|
||||
ready := filepath.Join(dir, "trap.ready")
|
||||
|
||||
// On SIGTERM, sleep past the (short) WaitDelay, then write the marker and
|
||||
// exit cleanly. If WaitDelay still drove the kill, bash would be SIGKILLed
|
||||
// mid-handler and the marker would never be written. The ready file is
|
||||
// written only after the trap is installed so the test does not race
|
||||
// SIGTERM ahead of it (CheckEndpoint:none marks ready before bash runs).
|
||||
script := filepath.Join(dir, "graceful.sh")
|
||||
body := fmt.Sprintf(
|
||||
"#!/bin/bash\ncleanup() { sleep 0.6; echo done > %q; exit 0; }\ntrap cleanup SIGTERM\necho ready > %q\nwhile true; do sleep 0.1; done\n",
|
||||
marker, ready,
|
||||
)
|
||||
if err := os.WriteFile(script, []byte(body), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: script,
|
||||
Proxy: "http://127.0.0.1:1", // unused: health check disabled
|
||||
CheckEndpoint: "none",
|
||||
})
|
||||
// WaitDelay shorter than the handler's 0.6s sleep, and far shorter than the
|
||||
// Stop timeout below — this is the window the old code mis-killed in.
|
||||
p.waitDelay = 200 * time.Millisecond
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Wait until the trap is installed before stopping.
|
||||
trapDeadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if _, err := os.Stat(ready); err == nil {
|
||||
break
|
||||
}
|
||||
if time.Now().After(trapDeadline) {
|
||||
t.Fatalf("script did not install SIGTERM trap in time")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
stopStart := time.Now()
|
||||
if err := p.Stop(5 * time.Second); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
elapsed := time.Since(stopStart)
|
||||
|
||||
// The handler must have run to completion (marker written) rather than
|
||||
// being force-killed at waitDelay.
|
||||
if _, err := os.Stat(marker); err != nil {
|
||||
t.Fatalf("graceful handler did not complete (marker missing): %v", err)
|
||||
}
|
||||
// And Stop must have waited for the handler (>~0.6s), not returned at the
|
||||
// 200ms waitDelay.
|
||||
if elapsed < 500*time.Millisecond {
|
||||
t.Fatalf("Stop returned in %v — process was killed before its graceful handler finished", elapsed)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopReapsForkedGrandchild verifies that stopping a forking
|
||||
// wrapper takes down the backgrounded grandchild too, rather than leaving it as
|
||||
// an orphan. The fix is Setpgid (runtime_unix.go): the wrapper leads its own
|
||||
// process group, so the stop signal is delivered to the whole group via the
|
||||
// negative PID and reaches the grandchild the wrapper never reaped.
|
||||
func TestProcessCommand_StopReapsForkedGrandchild(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
port := getFreePort(t)
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "child.pid")
|
||||
|
||||
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||
simpleResponderPath, port, pidFile)
|
||||
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: wrapper,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Read the grandchild PID the wrapper recorded.
|
||||
var childPID int
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
data, err := os.ReadFile(pidFile)
|
||||
if err == nil {
|
||||
if pid, perr := strconv.Atoi(strings.TrimSpace(string(data))); perr == nil && pid > 0 {
|
||||
childPID = pid
|
||||
break
|
||||
}
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("wrapper did not record grandchild PID")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
|
||||
// After Stop the grandchild must be gone. Signal 0 probes liveness without
|
||||
// actually sending a signal; give it a brief window to exit after the
|
||||
// group SIGTERM.
|
||||
proc, err := os.FindProcess(childPID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess: %v", err)
|
||||
}
|
||||
gone := false
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := proc.Signal(syscall.Signal(0)); err != nil {
|
||||
gone = true
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if !gone {
|
||||
t.Errorf("grandchild PID %d still alive after Stop — process group was not reaped", childPID)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// killChildFromPidFile reads a PID written by the wrapper script and SIGKILLs
|
||||
// it so leaked orphans don't accumulate between test runs. Best-effort.
|
||||
func killChildFromPidFile(pidFile string) {
|
||||
data, err := os.ReadFile(pidFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
if err != nil || pid <= 0 {
|
||||
return
|
||||
}
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = proc.Kill()
|
||||
}
|
||||
@@ -0,0 +1,646 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
const (
|
||||
testStartTimeout = 3 * time.Second
|
||||
testStopTimeout = 2 * time.Second
|
||||
testReturnTimeout = 1 * time.Second
|
||||
testPollInterval = 20 * time.Millisecond
|
||||
testLogPollInterval = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
p, err := New(context.Background(), t.Name(), conf, logger, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// runAsync starts Run in a goroutine and waits until the process is ready,
|
||||
// matching the new interface contract where Run blocks until the process is
|
||||
// terminated. Returns a channel that delivers Run's eventual error.
|
||||
func runAsync(t *testing.T, p *ProcessCommand) <-chan error {
|
||||
t.Helper()
|
||||
ch := make(chan error, 1)
|
||||
go func() { ch <- p.Run(testStartTimeout) }()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout)
|
||||
defer cancel()
|
||||
if err := p.WaitReady(ctx); err != nil {
|
||||
t.Fatalf("WaitReady: %v", err)
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func TestProcessCommand_StartStop(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// before start: no handler
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("before start: expected 503, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||
t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||
}
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Errorf("after Run: expected state %s, got %s", StateReady, got)
|
||||
}
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("after Run: expected 200, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); body != "hello" {
|
||||
t.Errorf("expected body %q, got %q", "hello", body)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
select {
|
||||
case err := <-runErr:
|
||||
if err != nil {
|
||||
t.Errorf("Run() after Stop: expected nil, got %v", err)
|
||||
}
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
|
||||
// after stop: handler cleared
|
||||
rr = httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("after stop: expected 503, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||
t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCommand_Run_Idempotent(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
if err := p.Run(testStartTimeout); err == nil {
|
||||
t.Error("second Run() while running: expected error, got nil")
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCommand_Stop_Idempotent(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() before Run(): %v", err)
|
||||
}
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("first Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("second Stop() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is
|
||||
// executing its health-check loop returns ErrAbort to the Run caller.
|
||||
//
|
||||
// A blocking mock HTTP server is used as the proxy so the test can deterministically
|
||||
// know when doStart is inside the health-check loop before issuing Stop.
|
||||
func TestProcessCommand_StopCancelsRun(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
healthCheckStarted := make(chan struct{}, 1)
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Signal that a health check is in-flight, then block until the client
|
||||
// cancels (which happens when Stop cancels the start context).
|
||||
select {
|
||||
case healthCheckStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-r.Context().Done()
|
||||
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer mock.Close()
|
||||
|
||||
// simple-responder is the real process; health checks go to the blocking mock.
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 30,
|
||||
})
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- p.Run(testStartTimeout)
|
||||
}()
|
||||
|
||||
// Block until doStart is actually performing a health check, guaranteeing
|
||||
// that Run is in-flight when Stop is called.
|
||||
<-healthCheckStarted
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
|
||||
if err := <-runErrCh; !errors.Is(err, ErrStartAborted) {
|
||||
t.Errorf("expected ErrStartAborted from Run, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the
|
||||
// parent context while doStart is health-checking causes the process to
|
||||
// transition to StateShutdown promptly, not wait for the health-check timeout.
|
||||
//
|
||||
// This is the config-reload race: Stop() returns early when parentCtx is
|
||||
// already done and never writes to stopCh, so without a parentCtx.Done()
|
||||
// case in the inner select, the process would keep loading indefinitely.
|
||||
func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
healthCheckStarted := make(chan struct{}, 1)
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case healthCheckStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-r.Context().Done()
|
||||
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer mock.Close()
|
||||
|
||||
parentCtx, cancelParent := context.WithCancel(context.Background())
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p, err := New(parentCtx, t.Name(), config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 60,
|
||||
}, logger, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() { runErrCh <- p.Run(60 * time.Second) }()
|
||||
|
||||
<-healthCheckStarted
|
||||
|
||||
// Cancel parent context to simulate a config reload tearing down the old server.
|
||||
cancelParent()
|
||||
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if !strings.Contains(err.Error(), "shutdown") {
|
||||
t.Errorf("Run error = %v, want shutdown error", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("process did not shut down within 5s after parent context cancel during start")
|
||||
}
|
||||
|
||||
// Run() may return before the run() goroutine writes StateShutdown;
|
||||
// poll briefly to avoid a spurious race in the assertion.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if p.State() == StateShutdown {
|
||||
break
|
||||
}
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
if got := p.State(); got != StateShutdown {
|
||||
t.Errorf("after cancel: expected StateShutdown, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on
|
||||
// fresh processes to confirm they are reusable.
|
||||
func TestProcessCommand_RunStopCycle(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
for i := range 3 {
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("cycle %d Stop() error: %v", i, err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatalf("cycle %d: Run() did not return after Stop", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path:
|
||||
// the upstream responds healthy on /health (so Run completes), then on the
|
||||
// actual proxied request it hijacks the connection and closes it mid-body.
|
||||
// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error,
|
||||
// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover
|
||||
// and log the disconnect.
|
||||
//
|
||||
// Requests are issued through an httptest.NewServer wrapping the process so
|
||||
// the panic actually fires (httputil only panics on copy errors when the
|
||||
// request carries http.ServerContextKey, which a real server sets).
|
||||
//
|
||||
// see: https://github.com/golang/go/issues/23643
|
||||
func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// Send a Content-Length that promises 100 bytes, deliver only a few,
|
||||
// then slam the connection shut. The reverse proxy will see EOF
|
||||
// before the body is fully copied and panic with ErrAbortHandler.
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Errorf("upstream: hijack not supported")
|
||||
return
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Errorf("upstream: hijack: %v", err)
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial"))
|
||||
_ = conn.Close()
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
// Capture proxy log output so we can assert the recover message was
|
||||
// emitted by handlerFn.
|
||||
logBuf := &syncBuffer{}
|
||||
proxyLogger := logmon.NewWriter(logBuf)
|
||||
procLogger := logmon.NewWriter(io.Discard)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p, err := New(context.Background(), t.Name(), config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: upstream.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
}, procLogger, proxyLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
_ = runAsync(t, p)
|
||||
|
||||
// Wrap p in an httptest server so requests get http.ServerContextKey
|
||||
// automatically — that is what makes httputil.ReverseProxy raise the panic.
|
||||
front := httptest.NewServer(p)
|
||||
t.Cleanup(front.Close)
|
||||
|
||||
resp, err := http.Get(front.URL + "/disconnect")
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
const want = "recovered from upstream disconnection"
|
||||
deadline := time.Now().Add(testReturnTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if strings.Contains(logBuf.String(), want) {
|
||||
return
|
||||
}
|
||||
time.Sleep(testLogPollInterval)
|
||||
}
|
||||
t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String())
|
||||
}
|
||||
|
||||
// syncBuffer is a concurrent-safe bytes.Buffer for capturing logmon output.
|
||||
type syncBuffer struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *syncBuffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.Write(p)
|
||||
}
|
||||
|
||||
func (b *syncBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.String()
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL
|
||||
// automatically stops itself after the idle timeout has elapsed following its
|
||||
// last request.
|
||||
func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
|
||||
cfg := config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 1, // 1-second TTL
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, cfg)
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady, got %s", got)
|
||||
}
|
||||
|
||||
// Make one request to prime the last-use timestamp.
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Wait for the TTL goroutine to fire and the process to fully stop.
|
||||
// Poll for StateStopped directly to avoid racing the StateStopping
|
||||
// intermediate state that sits between StateReady and StateStopped.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for p.State() != StateStopped && time.Now().Before(deadline) {
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped)
|
||||
}
|
||||
|
||||
// Run() should have returned nil (clean stop from TTL).
|
||||
select {
|
||||
case err := <-runErr:
|
||||
if err != nil {
|
||||
t.Errorf("Run() after TTL stop: expected nil, got %v", err)
|
||||
}
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after TTL-induced stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests
|
||||
// prevent the TTL goroutine from stopping the process, and that the TTL timer
|
||||
// resets after each request completes.
|
||||
func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 1, // 1-second TTL
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep sending requests for 1.5s — past the 1s TTL — and verify
|
||||
// the process never stops while traffic is flowing.
|
||||
stopAt := time.Now().Add(1500 * time.Millisecond)
|
||||
for time.Now().Before(stopAt) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
if p.State() != StateReady {
|
||||
t.Fatalf("process was stopped during active traffic (state=%s)", p.State())
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady while traffic was active, got %s", got)
|
||||
}
|
||||
|
||||
// Now stop manually to clean up.
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not
|
||||
// spawn a TTL goroutine — the process stays ready until explicitly stopped.
|
||||
func TestProcessCommand_TTL_ZeroDisables(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 0, // disabled
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady, got %s", got)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is
|
||||
// enough to confirm the process remains ready.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got)
|
||||
}
|
||||
|
||||
// Cleanly stop.
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing
|
||||
// pairs to exercise the race detector and verify no deadlocks occur.
|
||||
func TestProcessCommand_ConcurrentRunStop(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
for range 10 {
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
cfg := config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, cfg)
|
||||
|
||||
runDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(runDone)
|
||||
p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race
|
||||
}()
|
||||
go func() {
|
||||
p.Stop(testStopTimeout) //nolint: errcheck
|
||||
}()
|
||||
|
||||
// Backstop: the racing Stop may have arrived before Run got on the
|
||||
// channel (making it a no-op), so keep stopping until Run unblocks.
|
||||
deadline := time.After(testStartTimeout)
|
||||
for done := false; !done; {
|
||||
select {
|
||||
case <-runDone:
|
||||
done = true
|
||||
case <-deadline:
|
||||
t.Fatal("Run did not return")
|
||||
case <-time.After(testPollInterval):
|
||||
p.Stop(testStopTimeout) //nolint: errcheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
var transitions []shared.ProcessStateChangeEvent
|
||||
cancel := event.On(func(e shared.ProcessStateChangeEvent) {
|
||||
if e.ProcessName != t.Name() {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
transitions = append(transitions, e)
|
||||
mu.Unlock()
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
<-runErr
|
||||
|
||||
// Events are delivered asynchronously; give the dispatcher a moment.
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
mu.Lock()
|
||||
n := len(transitions)
|
||||
mu.Unlock()
|
||||
if n >= 4 {
|
||||
break
|
||||
}
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for _, e := range transitions {
|
||||
if e.OldState == e.NewState {
|
||||
t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState)
|
||||
}
|
||||
}
|
||||
|
||||
want := []string{
|
||||
string(StateStopped) + "->" + string(StateStarting),
|
||||
string(StateStarting) + "->" + string(StateReady),
|
||||
string(StateReady) + "->" + string(StateStopping),
|
||||
string(StateStopping) + "->" + string(StateStopped),
|
||||
}
|
||||
got := make([]string, len(transitions))
|
||||
for i, e := range transitions {
|
||||
got[i] = e.OldState + "->" + e.NewState
|
||||
}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("transitions = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("transitions = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes starts the upstream in its own process group (Setpgid) so
|
||||
// the entire process tree can be signalled at once via its negative PID. This
|
||||
// is what lets us reap a forked grandchild — e.g. a shell wrapper that
|
||||
// backgrounds the real binary and exits — instead of leaking it as an orphan
|
||||
// that holds the inherited stdout/stderr pipes open.
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
// terminateProcessTree sends SIGTERM to the whole process group led by the
|
||||
// command, giving every process in the tree a chance to shut down gracefully.
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
return signalProcessTree(cmd, syscall.SIGTERM)
|
||||
}
|
||||
|
||||
// killProcessTree sends SIGKILL to the whole process group, force-terminating
|
||||
// every process in the tree.
|
||||
func killProcessTree(cmd *exec.Cmd) error {
|
||||
return signalProcessTree(cmd, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
// signalProcessTree signals the process group led by cmd.Process. Because the
|
||||
// child was started with Setpgid it is its own group leader (pgid == pid), so
|
||||
// targeting -pid reaches the child and every descendant still in the group.
|
||||
// Falls back to signalling just the child if the group send fails (e.g. the
|
||||
// group has already drained), so we never silently skip the signal.
|
||||
func signalProcessTree(cmd *exec.Cmd, sig syscall.Signal) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
if err := syscall.Kill(-cmd.Process.Pid, sig); err != nil {
|
||||
return cmd.Process.Signal(sig)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes. CREATE_NO_WINDOW
|
||||
// keeps the upstream from spawning its own console window.
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
|
||||
// terminateProcessTree requests a graceful shutdown of the whole process tree
|
||||
// rooted at cmd.Process. Windows has no SIGTERM or process-group signalling, so
|
||||
// we shell out to `taskkill /t`, which walks the child tree by PID — the
|
||||
// equivalent of signalling a Unix process group. Without /f, taskkill asks the
|
||||
// processes to close rather than force-killing them.
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
return taskkillProcessTree(cmd, false)
|
||||
}
|
||||
|
||||
// killProcessTree force-terminates the whole process tree rooted at cmd.Process
|
||||
// via `taskkill /f /t`, so any descendant that ignored or outlived the graceful
|
||||
// request is killed alongside the parent rather than leaked as an orphan.
|
||||
func killProcessTree(cmd *exec.Cmd) error {
|
||||
return taskkillProcessTree(cmd, true)
|
||||
}
|
||||
|
||||
// taskkillProcessTree runs taskkill against cmd.Process.Pid. The /t flag
|
||||
// terminates the process together with any child processes it started, which is
|
||||
// the Windows analogue of signalling a Unix process group via its negative PID.
|
||||
// When force is true the /f flag force-kills; otherwise taskkill requests a
|
||||
// graceful close.
|
||||
func taskkillProcessTree(cmd *exec.Cmd, force bool) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
args := make([]string, 0, 4)
|
||||
if force {
|
||||
args = append(args, "/f")
|
||||
}
|
||||
args = append(args, "/t", "/pid", fmt.Sprintf("%d", cmd.Process.Pid))
|
||||
kill := exec.Command("taskkill", args...)
|
||||
setProcAttributes(kill)
|
||||
return kill.Run()
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
// SetupTreeCleanup is a no-op on non-Windows platforms, where upstream process
|
||||
// teardown is handled via process-group signalling (see runtime_unix.go).
|
||||
func SetupTreeCleanup() error { return nil }
|
||||
@@ -0,0 +1,50 @@
|
||||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// SetupTreeCleanup assigns the current process to a Windows Job Object
|
||||
// configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. Upstream processes
|
||||
// spawned afterwards are associated with the same job, so when llama-swap exits
|
||||
// for any reason — graceful shutdown, a forced second Ctrl+C, or a crash — the
|
||||
// OS terminates the whole job and reaps every child instead of leaving orphans
|
||||
// behind. It is the parent-side complement to the per-process teardown in
|
||||
// runtime_windows.go.
|
||||
//
|
||||
// The job handle is intentionally leaked for the lifetime of the process: the
|
||||
// kill-on-close behaviour fires when the last handle is released, which the OS
|
||||
// does when the process exits.
|
||||
func SetupTreeCleanup() error {
|
||||
job, err := windows.CreateJobObject(nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateJobObject: %w", err)
|
||||
}
|
||||
|
||||
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
|
||||
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
|
||||
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||
},
|
||||
}
|
||||
if _, err := windows.SetInformationJobObject(
|
||||
job,
|
||||
windows.JobObjectExtendedLimitInformation,
|
||||
uintptr(unsafe.Pointer(&info)),
|
||||
uint32(unsafe.Sizeof(info)),
|
||||
); err != nil {
|
||||
windows.CloseHandle(job)
|
||||
return fmt.Errorf("SetInformationJobObject: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.AssignProcessToJobObject(job, windows.CurrentProcess()); err != nil {
|
||||
windows.CloseHandle(job)
|
||||
return fmt.Errorf("AssignProcessToJobObject: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,505 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type shutdownReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type unloadReq struct {
|
||||
targets []string
|
||||
timeout time.Duration
|
||||
respond chan struct{}
|
||||
}
|
||||
|
||||
// baseRouter owns the channels, run-loop, and process machinery shared by every
|
||||
// concrete router. Concrete routers embed *baseRouter and supply a
|
||||
// scheduler.Swapper describing how eviction sets are decided. baseRouter
|
||||
// implements scheduler.Effects so the scheduler can call back for side-effects.
|
||||
type baseRouter struct {
|
||||
name string
|
||||
config config.Config
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
schedule scheduler.Scheduler
|
||||
|
||||
// shutdownCtx governs the request machinery: cancelling it tells grant()
|
||||
// and ServeHTTP to stop granting and reject callers. It is deliberately
|
||||
// separate from procCtx — see procCtx below.
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
|
||||
// procCtx is the parent context for every managed process and governs
|
||||
// process lifetime only. handleShutdown stops processes gracefully via
|
||||
// Stop() and cancels procCtx afterwards, so teardown is never a context
|
||||
// cancel racing the graceful path (which collapsed the grace to 100ms and
|
||||
// let the caller return before children were reaped — see process run loop).
|
||||
procCtx context.Context
|
||||
procCancel context.CancelFunc
|
||||
|
||||
handlerCh chan scheduler.HandlerReq
|
||||
cancelCh chan scheduler.HandlerReq
|
||||
shutdownCh chan shutdownReq
|
||||
unloadCh chan unloadReq
|
||||
swapDoneCh chan scheduler.SwapDone
|
||||
serveDoneCh chan scheduler.ServeDoneEvent
|
||||
|
||||
runDone chan struct{}
|
||||
|
||||
// testProcessed, when non-nil, receives one event after each handlerReq
|
||||
// or swapDone has been fully processed by run(). Tests use it to wait
|
||||
// for run() to reach a deterministic state without sleeping. serveDone
|
||||
// events are intentionally NOT signalled here so test event counts
|
||||
// remain stable.
|
||||
testProcessed chan struct{}
|
||||
}
|
||||
|
||||
func newBaseRouter(
|
||||
name string,
|
||||
conf config.Config,
|
||||
processes map[string]process.Process,
|
||||
logger *logmon.Monitor,
|
||||
planner scheduler.Swapper,
|
||||
) (*baseRouter, error) {
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
procCtx, procCancel := context.WithCancel(context.Background())
|
||||
b := &baseRouter{
|
||||
name: name,
|
||||
config: conf,
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
procCtx: procCtx,
|
||||
procCancel: procCancel,
|
||||
handlerCh: make(chan scheduler.HandlerReq),
|
||||
cancelCh: make(chan scheduler.HandlerReq),
|
||||
shutdownCh: make(chan shutdownReq),
|
||||
unloadCh: make(chan unloadReq),
|
||||
swapDoneCh: make(chan scheduler.SwapDone),
|
||||
serveDoneCh: make(chan scheduler.ServeDoneEvent),
|
||||
runDone: make(chan struct{}),
|
||||
}
|
||||
sched, err := scheduler.New(conf, name, logger, planner, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.schedule = sched
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *baseRouter) notifyProcessed() {
|
||||
if b.testProcessed != nil {
|
||||
b.testProcessed <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) run() {
|
||||
defer close(b.runDone)
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh:
|
||||
b.handleShutdown(req)
|
||||
return
|
||||
|
||||
case req := <-b.handlerCh:
|
||||
b.schedule.OnRequest(req)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.cancelCh:
|
||||
b.schedule.OnCancel(req)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.unloadCh:
|
||||
b.schedule.OnUnload(req.targets, req.timeout)
|
||||
close(req.respond)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.swapDoneCh:
|
||||
b.schedule.OnSwapDone(ev)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.serveDoneCh:
|
||||
b.schedule.OnServeDone(ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// grant sends a response back to the caller of ServeHTTP and tells us
|
||||
// whether the caller was still there to receive it.
|
||||
//
|
||||
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
|
||||
// a select waiting on it. "Unbuffered" is the important word: a send only
|
||||
// completes when the other side is actively receiving. So if this send
|
||||
// succeeds, we know for a fact the caller picked up the response and will
|
||||
// act on it. If the caller has already given up (its request context was
|
||||
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
|
||||
// down, the send never lands, one of the other select cases fires, and we
|
||||
// report back that the grant did NOT happen.
|
||||
//
|
||||
// That distinction matters for in-flight bookkeeping — see GrantServe.
|
||||
func (b *baseRouter) grant(req scheduler.HandlerReq, resp scheduler.HandlerResp) bool {
|
||||
select {
|
||||
case req.Respond <- resp:
|
||||
return true
|
||||
case <-req.Ctx.Done():
|
||||
return false
|
||||
case <-b.shutdownCtx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ModelState implements scheduler.Effects.
|
||||
func (b *baseRouter) ModelState(modelID string) (process.ProcessState, bool) {
|
||||
p, ok := b.processes[modelID]
|
||||
if !ok {
|
||||
var zero process.ProcessState
|
||||
return zero, false
|
||||
}
|
||||
return p.State(), true
|
||||
}
|
||||
|
||||
// StartSwap implements scheduler.Effects, launching the swap goroutine.
|
||||
func (b *baseRouter) StartSwap(modelID string, evict []string) {
|
||||
go b.doSwap(modelID, evict)
|
||||
}
|
||||
|
||||
// GrantError implements scheduler.Effects.
|
||||
func (b *baseRouter) GrantError(req scheduler.HandlerReq, err error) {
|
||||
b.grant(req, scheduler.HandlerResp{Err: err})
|
||||
}
|
||||
|
||||
// GrantServe implements scheduler.Effects. It hands the caller a wrapped
|
||||
// p.ServeHTTP (via trackedServe) so the run loop hears about the request
|
||||
// finishing, and reports whether the caller received it. The scheduler bumps
|
||||
// its in-flight count only on a true return: if grant() returns false the
|
||||
// caller already walked away and trackedServe will never run, so no matching
|
||||
// decrement will ever arrive — incrementing would strand the counter at >0 and
|
||||
// the router would never again be willing to evict this model.
|
||||
func (b *baseRouter) GrantServe(req scheduler.HandlerReq, modelID string) bool {
|
||||
p := b.processes[modelID]
|
||||
return b.grant(req, scheduler.HandlerResp{HandleFunc: b.trackedServe(modelID, p)})
|
||||
}
|
||||
|
||||
// StopProcesses implements scheduler.Effects, stopping the named processes in
|
||||
// parallel and blocking until all have stopped.
|
||||
func (b *baseRouter) StopProcesses(timeout time.Duration, ids []string) {
|
||||
var wg sync.WaitGroup
|
||||
for _, id := range ids {
|
||||
p, ok := b.processes[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(timeout); err != nil {
|
||||
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(id, p)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
||||
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
|
||||
// send on serveDoneCh after the handler returns. That send is what tells
|
||||
// the run loop "this model now has one fewer request in flight — go look
|
||||
// at the queue again, you may be able to start a swap you previously had
|
||||
// to defer."
|
||||
//
|
||||
// The select on shutdownCtx.Done() is a release valve: if the router is
|
||||
// already shutting down, nobody is reading serveDoneCh, so we drop the
|
||||
// notification rather than blocking the HTTP goroutine forever.
|
||||
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
select {
|
||||
case b.serveDoneCh <- scheduler.ServeDoneEvent{ModelID: modelID}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
p.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||
timeout := b.healthCheckTimeout()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mID := range toStop {
|
||||
wg.Add(1)
|
||||
go func(p process.Process, id string) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(timeout); err != nil {
|
||||
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(b.processes[mID], mID)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
target := b.processes[modelID]
|
||||
if target.State() == process.StateStopped {
|
||||
go func() {
|
||||
if err := target.Run(timeout); err != nil {
|
||||
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := target.WaitReady(b.shutdownCtx)
|
||||
|
||||
select {
|
||||
case b.swapDoneCh <- scheduler.SwapDone{ModelID: modelID, Err: err}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) handleShutdown(req shutdownReq) {
|
||||
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
||||
|
||||
// Cancel shutdownCtx first so any waiter that is currently parked on
|
||||
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
||||
// The OnShutdown grants below then either land (waiter happened to receive
|
||||
// before noticing shutdown) or fall through immediately via grant's
|
||||
// shutdownCtx case — either way the waiter sees a non-OK response.
|
||||
// This does NOT touch processes: their lifetime is procCtx, cancelled
|
||||
// only after the graceful Stop() calls below have reaped them.
|
||||
b.shutdownFn()
|
||||
|
||||
b.schedule.OnShutdown(shutdownErr)
|
||||
|
||||
stopTimeout := req.timeout
|
||||
if stopTimeout <= 0 {
|
||||
stopTimeout = b.healthCheckTimeout()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, p := range b.processes {
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(stopTimeout); err != nil {
|
||||
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
|
||||
}
|
||||
}(i, p)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if req.timeout > 0 {
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(req.timeout):
|
||||
<-done
|
||||
}
|
||||
} else {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
|
||||
// the process run-loop goroutines exit; they are already StateStopped, so
|
||||
// this is a clean no-op kill rather than a forced teardown.
|
||||
b.procCancel()
|
||||
|
||||
req.respond <- nil
|
||||
}
|
||||
|
||||
func (b *baseRouter) healthCheckTimeout() time.Duration {
|
||||
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
|
||||
if t <= 0 {
|
||||
return 30 * time.Second
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (b *baseRouter) Handles(model string) bool {
|
||||
_, ok := b.processes[model]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
|
||||
if p, ok := b.processes[modelID]; ok {
|
||||
return p.Logger(), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RunningModels returns the current state of every process that is not stopped
|
||||
// or shut down. The processes map keys are fixed at construction and State()
|
||||
// is a snapshot, so this is safe to call without the run loop.
|
||||
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
|
||||
running := make(map[string]process.ProcessState)
|
||||
for id, p := range b.processes {
|
||||
st := p.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
running[id] = st
|
||||
}
|
||||
return running
|
||||
}
|
||||
|
||||
// Unload stops the named models, or every running model when none are named.
|
||||
// It blocks until each targeted process has stopped.
|
||||
//
|
||||
// The request is funneled through the run loop so eviction is coordinated
|
||||
// with the rest of the router's state: pending swap waiters for an
|
||||
// unloaded model are released with an error, queued requests for unloaded
|
||||
// models are dropped, and any deferred swaps that were waiting on those
|
||||
// models become eligible to start.
|
||||
//
|
||||
// In-flight requests being served by an unloaded process are not waited
|
||||
// for — Stop kills the upstream, those callers see whatever error the
|
||||
// reverse proxy surfaces and may retry. Their trackedServe defers fire
|
||||
// normally and decrement inFlight as the dying handlers return.
|
||||
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
|
||||
targets := models
|
||||
if len(targets) == 0 {
|
||||
targets = make([]string, 0, len(b.processes))
|
||||
for id := range b.processes {
|
||||
targets = append(targets, id)
|
||||
}
|
||||
}
|
||||
if len(targets) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
|
||||
select {
|
||||
case b.unloadCh <- req:
|
||||
case <-b.runDone:
|
||||
return
|
||||
}
|
||||
<-req.respond
|
||||
}
|
||||
|
||||
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||
if !b.shuttingDown.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("%s shutdown already in progress", b.name)
|
||||
}
|
||||
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
|
||||
select {
|
||||
case b.shutdownCh <- req:
|
||||
case <-b.runDone:
|
||||
return nil
|
||||
}
|
||||
return <-req.respond
|
||||
}
|
||||
|
||||
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if b.shuttingDown.Load() {
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
data, err := shared.FetchContext(req, b.config)
|
||||
if err != nil {
|
||||
shared.SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
hr := scheduler.HandlerReq{
|
||||
Model: data.ModelID,
|
||||
Ctx: req.Context(),
|
||||
// Unbuffered: a successful send on Respond proves the waiter is
|
||||
// alive and consuming. grant() relies on this to avoid handing a
|
||||
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
||||
Respond: make(chan scheduler.HandlerResp),
|
||||
PositionCh: make(chan int, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case b.handlerCh <- hr:
|
||||
case <-req.Context().Done():
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
isModelReady := false
|
||||
if p, ok := b.processes[data.ModelID]; ok {
|
||||
isModelReady = p.State() == process.StateReady
|
||||
}
|
||||
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
|
||||
|
||||
var lw *loadingWriter
|
||||
cancelLoad := func() {}
|
||||
if shouldShowLoading {
|
||||
var swapCtx context.Context
|
||||
swapCtx, cancelLoad = context.WithCancel(req.Context())
|
||||
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
|
||||
go lw.start(swapCtx)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case pos := <-hr.PositionCh:
|
||||
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
||||
case <-swapCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// finishLoading stops the loading stream and fences its goroutine off from
|
||||
// the ResponseWriter before the real handler (or ServeHTTP's return)
|
||||
// reclaims it. release() must run even when waitForCompletion times out:
|
||||
// otherwise a still-streaming goroutine flushes a finalized response and
|
||||
// panics on the recycled *bufio.Writer.
|
||||
finishLoading := func() {
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
lw.release()
|
||||
}
|
||||
}
|
||||
|
||||
var resp scheduler.HandlerResp
|
||||
select {
|
||||
case resp = <-hr.Respond:
|
||||
finishLoading()
|
||||
case <-req.Context().Done():
|
||||
finishLoading()
|
||||
// Notify the scheduler so it can prune this request from its queue
|
||||
// and swap waiters. Without this, a queued request whose client left
|
||||
// would sit in the scheduler until drainQueue eventually starts a
|
||||
// wasted model load for it.
|
||||
select {
|
||||
case b.cancelCh <- hr:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
finishLoading()
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Err != nil {
|
||||
shared.SendError(w, req, resp.Err)
|
||||
return
|
||||
}
|
||||
resp.HandleFunc(w, req)
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
// These tests cover baseRouter's own machinery — the run loop, process
|
||||
// lifecycle (doSwap), grant/ServeHTTP plumbing, Unload, and Shutdown. The
|
||||
// scheduling decision logic (queueing, collation, eviction collisions) lives in
|
||||
// the scheduler package and is tested directly there; see fifo_test.go.
|
||||
|
||||
// stubPlanner evicts nothing. baseRouter tests drive the run loop through the
|
||||
// default FIFO scheduler without exercising any particular eviction policy.
|
||||
type stubPlanner struct{}
|
||||
|
||||
func (s *stubPlanner) EvictionFor(string, []string) []string { return nil }
|
||||
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||
|
||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
|
||||
t.Helper()
|
||||
conf := config.Config{HealthCheckTimeout: 5}
|
||||
b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner)
|
||||
if err != nil {
|
||||
t.Fatalf("newBaseRouter: %v", err)
|
||||
}
|
||||
b.testProcessed = make(chan struct{}, 64)
|
||||
go b.run()
|
||||
t.Cleanup(func() {
|
||||
if !b.shuttingDown.Load() {
|
||||
_ = b.Shutdown(time.Second)
|
||||
}
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
func TestBaseRouter_RunningModels(t *testing.T) {
|
||||
ready := newFakeProcess("ready")
|
||||
ready.markReady()
|
||||
starting := newFakeProcess("starting")
|
||||
starting.setState(process.StateStarting)
|
||||
stopped := newFakeProcess("stopped")
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{
|
||||
"ready": ready, "starting": starting, "stopped": stopped,
|
||||
}, &stubPlanner{})
|
||||
|
||||
running := b.RunningModels()
|
||||
if len(running) != 2 {
|
||||
t.Fatalf("running=%v want 2 entries", running)
|
||||
}
|
||||
if running["ready"] != process.StateReady {
|
||||
t.Errorf("ready state=%q want ready", running["ready"])
|
||||
}
|
||||
if running["starting"] != process.StateStarting {
|
||||
t.Errorf("starting state=%q want starting", running["starting"])
|
||||
}
|
||||
if _, ok := running["stopped"]; ok {
|
||||
t.Errorf("stopped process should be excluded from RunningModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_UnloadAll(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
c := newFakeProcess("c")
|
||||
c.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
|
||||
b.Unload(time.Second)
|
||||
|
||||
if a.State() != process.StateStopped || c.State() != process.StateStopped {
|
||||
t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_UnloadSpecificModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
c := newFakeProcess("c")
|
||||
c.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
|
||||
b.Unload(time.Second, "a")
|
||||
|
||||
if a.State() != process.StateStopped {
|
||||
t.Errorf("a should be stopped, got %q", a.State())
|
||||
}
|
||||
if c.State() != process.StateReady {
|
||||
t.Errorf("c should remain ready, got %q", c.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its
|
||||
// Stop calls concurrently rather than stopping each process serially. Each
|
||||
// fakeProcess.Stop is pinned via stopBlock; the test only releases them
|
||||
// after observing every stopStarted, proving all three Stops were in
|
||||
// flight simultaneously.
|
||||
func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
a.stopBlock = make(chan struct{})
|
||||
pb := newFakeProcess("b")
|
||||
pb.markReady()
|
||||
pb.stopBlock = make(chan struct{})
|
||||
pc := newFakeProcess("c")
|
||||
pc.markReady()
|
||||
pc.stopBlock = make(chan struct{})
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{})
|
||||
|
||||
unloadDone := make(chan struct{})
|
||||
go func() {
|
||||
b.Unload(time.Second, "a", "b", "c")
|
||||
close(unloadDone)
|
||||
}()
|
||||
|
||||
// All three Stop calls must start before any of them are allowed to
|
||||
// complete. If Unload was serial, only one stopStarted would fire
|
||||
// until we released its stopBlock, and this would deadlock.
|
||||
for _, p := range []*fakeProcess{a, pb, pc} {
|
||||
select {
|
||||
case <-p.stopStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("Stop on %s never started — Unload is not parallel", p.id)
|
||||
}
|
||||
}
|
||||
|
||||
// Release them; Unload should now return.
|
||||
close(a.stopBlock)
|
||||
close(pb.stopBlock)
|
||||
close(pc.stopBlock)
|
||||
|
||||
select {
|
||||
case <-unloadDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Unload did not return after stops released")
|
||||
}
|
||||
|
||||
for _, p := range []*fakeProcess{a, pb, pc} {
|
||||
if p.State() != process.StateStopped {
|
||||
t.Errorf("%s state=%q want stopped", p.id, p.State())
|
||||
}
|
||||
if got := p.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("%s stopCalls=%d want 1", p.id, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.autoReady = true
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("runCalls=%d want 1", got)
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("serveCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so swap parks forever until we mark ready.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequestCtx(ctx, "a"))
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("a"))
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap
|
||||
<-a.runStarted
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("cancelled ServeHTTP did not return after ctx cancel")
|
||||
}
|
||||
|
||||
a.markReady()
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("non-cancelled ServeHTTP did not complete after swap")
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("unknown"))
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
pb := newFakeProcess("b")
|
||||
pb.markReady()
|
||||
go pb.Run(0)
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||
|
||||
if err := b.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("b.stopCalls=%d want 1", got)
|
||||
}
|
||||
|
||||
// Subsequent ServeHTTP should report 5xx.
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Second Shutdown should report already in progress.
|
||||
if err := b.Shutdown(0); err == nil {
|
||||
t.Errorf("second Shutdown returned nil, want error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,404 @@
|
||||
# Router design
|
||||
|
||||
A developer tutorial for the `internal/router` package and its `scheduler`
|
||||
sub-package.
|
||||
|
||||
## Intro
|
||||
|
||||
A llama-swap router is the component that sits behind the proxy and answers one
|
||||
question for every incoming request: _can this model serve right now, and if
|
||||
not, what has to happen first?_ Answering it means juggling three concerns that
|
||||
used to live tangled together in one type:
|
||||
|
||||
1. **Process machinery** — owning the OS processes, starting and stopping them,
|
||||
running health checks, and shuttling HTTP requests onto the right upstream.
|
||||
2. **Scheduling strategy** — the queue, in-flight bookkeeping, and the decision
|
||||
tree that turns one request into "serve now", "join an existing swap",
|
||||
"queue", or "start a swap".
|
||||
3. **Eviction policy** — given a model we want to load, which currently-running
|
||||
models have to be stopped to make room?
|
||||
|
||||
The design pulls those three apart into separate, independently replaceable
|
||||
pieces:
|
||||
|
||||
| Concern | Type | Lives in |
|
||||
| ------------------- | ------------------------------ | ------------------------------- |
|
||||
| Process machinery | `baseRouter` | `internal/router/base.go` |
|
||||
| Scheduling strategy | `scheduler.Scheduler` (`FIFO`) | `internal/router/scheduler/` |
|
||||
| Eviction policy | `scheduler.Swapper` | `groupSwapper`, `matrixSwapper` |
|
||||
|
||||
`baseRouter` keeps the channels, run loop, process lifecycle, and shutdown
|
||||
teardown, and exposes the side-effects a scheduler needs through the
|
||||
`scheduler.Effects` interface. The scheduler owns the queue and decision tree
|
||||
but performs no side-effects directly — it calls back through `Effects`. The
|
||||
`Swapper` is a pure function from "target model + currently running" to "models
|
||||
to evict", and knows nothing about queues, channels, or processes.
|
||||
|
||||
Because the seams are interfaces, you can replace the scheduling strategy
|
||||
without touching process management, or write a new eviction policy without
|
||||
touching either. `FIFO` is the first and currently only `Scheduler`;
|
||||
`groupSwapper` and `matrixSwapper` are the two `Swapper`s.
|
||||
|
||||
## Key concepts
|
||||
|
||||
### One run loop, no locks
|
||||
|
||||
`baseRouter.run()` is a single goroutine selecting over a handful of channels:
|
||||
|
||||
```go
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh: b.handleShutdown(req); return
|
||||
case req := <-b.handlerCh: b.schedule.OnRequest(req)
|
||||
case req := <-b.unloadCh: b.schedule.OnUnload(req.targets, req.timeout); close(req.respond)
|
||||
case ev := <-b.swapDoneCh: b.schedule.OnSwapDone(ev)
|
||||
case ev := <-b.serveDoneCh: b.schedule.OnServeDone(ev)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Every `Scheduler` method runs on this one goroutine. That is the single most
|
||||
important fact about the design: **the scheduler never needs a mutex for its own
|
||||
state**. All scheduler state is touched only from these callbacks, which are
|
||||
serialized by the run loop. If you write a new scheduler, you get the same
|
||||
guarantee for free — and you must not break it by spinning up goroutines that
|
||||
mutate scheduler state.
|
||||
|
||||
### Events flow in, side-effects flow out
|
||||
|
||||
The run loop turns external happenings into method calls on the scheduler:
|
||||
|
||||
- A new HTTP request becomes `OnRequest(HandlerReq)`.
|
||||
- A swap goroutine finishing becomes `OnSwapDone(SwapDone)`.
|
||||
- A tracked request handler returning becomes `OnServeDone(ServeDoneEvent)`.
|
||||
- An admin unload becomes `OnUnload(targets, timeout)`.
|
||||
- Shutdown becomes `OnShutdown(err)`.
|
||||
|
||||
The scheduler reacts by calling **back out** through `Effects`: inspect a
|
||||
process state, start a swap, grant a response to a caller, or stop processes. It
|
||||
never calls `process.Process` directly and never writes to a channel directly.
|
||||
This keeps the scheduler pure enough to unit-test against a fake `Effects` with
|
||||
no goroutines or real processes involved (see `scheduler/fifo_test.go`).
|
||||
|
||||
```
|
||||
HTTP request admin Unload / Shutdown
|
||||
│ │
|
||||
▼ ▼
|
||||
ServeHTTP ──HandlerReq──▶ baseRouter.run() ◀──unloadCh/shutdownCh
|
||||
│ (single goroutine)
|
||||
▼
|
||||
Scheduler.On*(...)
|
||||
│ calls back through
|
||||
▼
|
||||
Effects: ModelState / StartSwap /
|
||||
GrantServe / GrantError / StopProcesses
|
||||
│
|
||||
▼
|
||||
baseRouter side-effects: doSwap goroutine,
|
||||
grant() to caller, process.Stop()
|
||||
│
|
||||
swap completes ──SwapDone──▶ back into run loop
|
||||
```
|
||||
|
||||
### The swap goroutine
|
||||
|
||||
Scheduling decisions must be quick and non-blocking, but loading a model is
|
||||
slow. The two are reconciled by doing the slow part on a separate goroutine.
|
||||
|
||||
When the scheduler decides to start a swap, inside `OnRequest` it:
|
||||
|
||||
1. records "a swap for X is in flight" in its own state, then
|
||||
2. calls `Effects.StartSwap(modelID, evict)`.
|
||||
|
||||
`StartSwap` does **not** load the model itself — it just launches a detached
|
||||
goroutine (`doSwap`) and returns straight away. `doSwap` is what does the slow
|
||||
work: stop the evicted processes, start the target, wait for it to become ready.
|
||||
Because `StartSwap` returned immediately, `OnRequest` returns too, and the run
|
||||
loop is free to pick up the next event — another request, a serve-done, an
|
||||
unload — while `doSwap` runs in the background.
|
||||
|
||||
The swap's eventual result comes back as just another event: when `doSwap`
|
||||
finishes it posts a `SwapDone` onto `swapDoneCh`, which the run loop delivers as
|
||||
`OnSwapDone`. So a slow load never blocks the run loop; it brackets it with two
|
||||
quick events (`OnRequest` to start, `OnSwapDone` to finish) and everything in
|
||||
between is handled normally.
|
||||
|
||||
### In-flight tracking and `trackedServe`
|
||||
|
||||
When the scheduler grants a request, the handler it hands back is wrapped by
|
||||
`baseRouter.trackedServe`. The wrapper runs the real `ServeHTTP` and, on return,
|
||||
posts a `ServeDoneEvent` so the run loop can decrement the per-model in-flight
|
||||
count. This is why the scheduler can know whether a process is "busy": it counts
|
||||
grants out and serve-dones in. A swap that would evict a busy process is
|
||||
deferred until that process's in-flight count hits zero (`OnServeDone` then
|
||||
re-drains the queue).
|
||||
|
||||
The subtle contract here is `GrantServe`'s boolean return. The caller's
|
||||
`Respond` channel is unbuffered, so a successful send proves the HTTP goroutine
|
||||
is alive and took the handler. If the caller already disconnected, the send
|
||||
fails, `trackedServe` never runs, and **no** `ServeDoneEvent` will ever arrive —
|
||||
so the scheduler must only increment `inFlight` when `GrantServe` returns true.
|
||||
Incrementing on a false return would strand the counter above zero and the model
|
||||
could never be evicted again.
|
||||
|
||||
## The interfaces
|
||||
|
||||
All three live in `scheduler/scheduler.go`.
|
||||
|
||||
### `Scheduler`
|
||||
|
||||
```go
|
||||
type Scheduler interface {
|
||||
OnRequest(req HandlerReq)
|
||||
OnSwapDone(ev SwapDone)
|
||||
OnServeDone(ev ServeDoneEvent)
|
||||
OnUnload(targets []string, timeout time.Duration)
|
||||
OnShutdown(err error)
|
||||
}
|
||||
```
|
||||
|
||||
Owns the queue, in-flight tracking, and the decision tree. All methods run on
|
||||
the run-loop goroutine, so no internal locking is needed.
|
||||
|
||||
### `Swapper`
|
||||
|
||||
```go
|
||||
type Swapper interface {
|
||||
EvictionFor(target string, running []string) []string
|
||||
OnSwapStart(target string, running []string)
|
||||
}
|
||||
```
|
||||
|
||||
The eviction policy. `EvictionFor` is a **pure decision** — given the target and
|
||||
the complete `running` set, return the running model IDs that must stop. It must
|
||||
not log or mutate anything, and it does **not** inspect process state itself:
|
||||
the scheduler hands it `running` already assembled (every non-stopped process,
|
||||
unioned with the targets of in-flight swaps already committed but not yet
|
||||
visible in process state). That keeps the swapper a pure function of its inputs,
|
||||
with no reference to processes.
|
||||
|
||||
The reason it must not log is that it is a _speculative_ query — "what would we
|
||||
evict if we started this swap right now?" — called far more often than swaps
|
||||
actually happen. The scheduler calls it once per incoming request, and then
|
||||
**again for every still-queued request on every queue drain** (each `OnSwapDone`,
|
||||
`OnServeDone`, and `OnUnload`). Most of those calls end in "still queued",
|
||||
"collides", or "nothing to evict", not a real swap. Logging there would emit
|
||||
duplicate lines for a request that simply sits in the queue, and lines for
|
||||
decisions that never happen — the log would stop meaning "a swap occurred".
|
||||
|
||||
`OnSwapStart` is the one place a Swapper may log, because it is called exactly
|
||||
once, at the moment a swap is committed. One log line there equals one real swap,
|
||||
with the evict set that is genuinely being applied — which is why `matrixSwapper`
|
||||
re-solves and logs the full decision (set, DSL, cost) in `OnSwapStart` rather
|
||||
than in `EvictionFor`.
|
||||
|
||||
### `Effects`
|
||||
|
||||
```go
|
||||
type Effects interface {
|
||||
ModelState(modelID string) (process.ProcessState, bool)
|
||||
RunningModels() map[string]process.ProcessState
|
||||
StartSwap(modelID string, evict []string)
|
||||
GrantError(req HandlerReq, err error)
|
||||
GrantServe(req HandlerReq, modelID string) bool
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
```
|
||||
|
||||
Implemented by `baseRouter`. This is the scheduler's entire window onto the
|
||||
outside world; everything else about the router is hidden from it. See the
|
||||
deep-dive below.
|
||||
|
||||
### `Factory` — wiring it together
|
||||
|
||||
```go
|
||||
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler
|
||||
```
|
||||
|
||||
`baseRouter` doesn't know which scheduler or swapper it has — it is handed a
|
||||
`Factory` at construction and calls it once, passing itself as the `Effects`.
|
||||
The concrete router captures its `Swapper` in the closure. From `group.go`:
|
||||
|
||||
```go
|
||||
swapper := &groupSwapper{ /* ... */ }
|
||||
base := newBaseRouter("group", conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, swapper, eff)
|
||||
})
|
||||
```
|
||||
|
||||
This closure is the single point where the three pieces meet: it binds a
|
||||
specific `Swapper` (`swapper`) and a specific `Scheduler` (`FIFO`) to the
|
||||
`baseRouter`'s `Effects` (`eff`).
|
||||
|
||||
**The swapper is a separate type from the concrete router.** There are currently two router implementations router.Group and router.Matrix. Each of these has a custom swapper that implements scheduler.Swapper for custom eviction logic. This decoupling of responsibilities makes it easy to implement custom swapping strategies.
|
||||
|
||||
### The events
|
||||
|
||||
A single goroutine in `baseRouter.run()` owns and serializes all state changes in the router. By processing events one at a time it ensures correctness and eliminates complex mutex lock logic.
|
||||
|
||||
These are the events the router currently uses:
|
||||
|
||||
```go
|
||||
type HandlerReq struct { // one in-flight ServeHTTP awaiting a decision
|
||||
Model string
|
||||
Ctx context.Context
|
||||
Respond chan HandlerResp // UNBUFFERED — see GrantServe contract
|
||||
PositionCh chan int // queue-position updates for the loading UI
|
||||
}
|
||||
|
||||
type HandlerResp struct { // the decision handed back to the caller
|
||||
HandleFunc http.HandlerFunc // serve with this, or...
|
||||
Err error // ...fail with this
|
||||
}
|
||||
|
||||
type SwapDone struct{ ModelID string; Err error } // swap goroutine finished
|
||||
type ServeDoneEvent struct{ ModelID string } // tracked handler returned
|
||||
```
|
||||
|
||||
## Deep-dive: the `Effects` interface and why it exists
|
||||
|
||||
`Effects` is the inversion-of-control boundary that makes the split possible.
|
||||
The scheduler decides and `baseRouter` _acts_. Pulling the side-effects behind this
|
||||
interface buys three things:
|
||||
|
||||
1. **Purity and testability.** The scheduler performs no I/O, starts no
|
||||
goroutines of its own, and touches no real processes. Its tests drive the
|
||||
`On*` methods directly and assert on a `fakeEffects` that just records the
|
||||
calls — synchronous, deterministic, no sleeps. (`scheduler/fifo_test.go`.)
|
||||
2. **A single, auditable side-effect surface.** Every externally-visible thing a
|
||||
scheduler can do is one of six methods. You can reason about the whole
|
||||
contract by reading one interface.
|
||||
3. **Decoupling lifetime.** The scheduler never holds a `process.Process`,
|
||||
never sees a channel, and never learns how shutdown teardown works. It only
|
||||
knows model IDs and states.
|
||||
|
||||
Method by method, as implemented in `base.go`:
|
||||
|
||||
- **`ModelState(modelID) (state, ok)`** — read-only snapshot of a process's
|
||||
state, and whether this router handles the model at all. The scheduler uses it
|
||||
for the "unknown model" check and the "already ready" fast path. Safe to call
|
||||
any time because the process map is fixed at construction and `State()` is a
|
||||
snapshot.
|
||||
|
||||
- **`RunningModels()`** — the state of every process that isn't stopped or shut
|
||||
down. The scheduler unions its keys with its own in-flight swap targets to
|
||||
build the `running` set it hands the `Swapper`, so the swapper never has to
|
||||
touch process state itself.
|
||||
|
||||
- **`StartSwap(modelID, evict)`** — fire-and-forget. `baseRouter` launches the
|
||||
`doSwap` goroutine and returns immediately; the result comes back later as a
|
||||
`SwapDone`. The scheduler records the swap as active _before_ calling this so
|
||||
that requests arriving in the meantime can join it.
|
||||
|
||||
- **`GrantError(req, err)`** — hand a caller an error response. Used for unknown
|
||||
models, failed swaps, unloads, and shutdown.
|
||||
|
||||
- **`GrantServe(req, modelID) bool`** — hand a caller the tracked handler for a
|
||||
ready model, returning whether the caller was still there to receive it. The
|
||||
scheduler increments the in-flight count **only on a true return** (see the
|
||||
in-flight contract above). This is the one `Effects` method whose return value
|
||||
carries state-machine significance.
|
||||
|
||||
- **`StopProcesses(timeout, ids)`** — stop processes in parallel and **block**
|
||||
until all have stopped. Used by `OnUnload` so an admin `Unload` call can
|
||||
guarantee the process is dead by the time it returns. (Note `StartSwap` is
|
||||
async but `StopProcesses` is sync — the difference is deliberate and tied to
|
||||
the caller's expectations.)
|
||||
|
||||
A useful way to hold it in your head: `Effects` is the scheduler's syscall
|
||||
table. The scheduler is a pure state machine; `Effects` is how it touches the
|
||||
world, and `baseRouter` is the kernel that implements those syscalls with real
|
||||
goroutines, channels, and processes.
|
||||
|
||||
## How to implement a new `Swapper`
|
||||
|
||||
A `Swapper` is a pure decision function plus a logging hook — the easiest of the three pieces to replace.
|
||||
|
||||
1. **Write the swapper type** and give it whatever config it needs to make a
|
||||
decision. It does **not** need the process map — the scheduler supplies the
|
||||
running set as an argument. `groupSwapper` holds only its group config;
|
||||
`matrixSwapper` holds only its solver and logger:
|
||||
|
||||
```go
|
||||
type mySwapper struct {
|
||||
config config.Config
|
||||
}
|
||||
```
|
||||
|
||||
2. **Implement `EvictionFor(target, running)`** as a _pure_ decision:
|
||||
- `running` is the complete live set, already assembled for you: every
|
||||
non-stopped process unioned with the targets of in-flight swaps the
|
||||
scheduler has committed to. You don't filter process state or fold in
|
||||
in-flight targets yourself, that's the scheduler's job. Just decide against the slice you're handed.
|
||||
- Return the list of model IDs in `running` that must stop for `target` to
|
||||
run. Return `nil`/empty when nothing needs evicting.
|
||||
- Do **not** mutate state here.
|
||||
- Do **not** log here. It can be called multiple times per request. Since it is pure function have tests verify the expected behaviour.
|
||||
|
||||
3. **Implement `OnSwapStart(target, running)`** — called once when a swap
|
||||
actually begins, with the same `running` set `EvictionFor` saw. This is the
|
||||
right place to log: one call equals one real swap. `matrixSwapper` re-solves
|
||||
and logs the chosen set and cost here; `groupSwapper` logs nothing.
|
||||
|
||||
4. **Wire it in** by instantiating the swapper in your router's constructor and
|
||||
capturing it in the `Factory` closure passed to `newBaseRouter` — exactly as
|
||||
`NewGroup` and `NewMatrix` do. The router struct itself only ever embeds
|
||||
`*baseRouter`; the swapper reaches the scheduler solely through that closure.
|
||||
|
||||
Reference implementations: `groupSwapper` (static group config) in `group.go`
|
||||
and `matrixSwapper` (cost-based set solver) in `matrix.go`.
|
||||
|
||||
## How to implement a new `Scheduler`
|
||||
|
||||
Replacing the scheduler means taking over the queue and the entire decision tree. Read `scheduler/fifo.go` end to end first — it is the reference implementation and the rules below are easiest to understand in context.
|
||||
|
||||
The rules you must honour:
|
||||
|
||||
- **Single goroutine.** Every method runs on the `baseRouter.run()` goroutine. Keep your state in plain maps/slices and never read or write it from another goroutine. If you need slow work done, hand it to `Effects.StartSwap` and react to the resulting `SwapDone` — do not block a method waiting for it.
|
||||
|
||||
- **Never block the run loop.** `OnRequest`, `OnSwapDone`, and `OnServeDone` must make a decision and return. The one method allowed to block is `OnUnload`, and only because it must wait on the synchronous `StopProcesses` so the admin caller's guarantee holds.
|
||||
|
||||
- **Respect the `GrantServe` boolean.** Only count a request as in-flight when `GrantServe` returns true (see the in-flight contract above). A false return means the caller is gone; no `ServeDoneEvent` will ever arrive, so incrementing on false permanently strands the counter.
|
||||
|
||||
- **Account for in-flight swaps in your running set.** When you call `Swapper.EvictionFor`, the running set you pass must include not just live processes (`Effects.RunningModels`) but also the targets of swaps you've already started that aren't yet visible in process state — otherwise the swapper contradicts decisions already in motion.
|
||||
|
||||
What each method must do:
|
||||
|
||||
- **`OnRequest(req)`** — every request must resolve to exactly one of: granted, errored, joined (piggybacks an in-flight swap), queued, or swap-started. No request may be silently dropped.
|
||||
|
||||
- **`OnSwapDone(ev)`** — deliver the result to every waiter that joined this swap (grant on success, error on `ev.Err`), drop the swap from active tracking, then re-examine anything queued — a finished swap may have unblocked it.
|
||||
|
||||
- **`OnServeDone(ev)`** — decrement the model's in-flight count; when it hits zero, re-examine the queue. Do **not** clear in-flight counts by hand; the handlers post their own `ServeDoneEvent`s on return.
|
||||
|
||||
- **`OnUnload(targets, timeout)`** — error out any waiters or queued requests for the unloaded models, call `Effects.StopProcesses` (synchronously — the admin caller relies on the process being dead afterwards), then re-examine the queue.
|
||||
|
||||
- **`OnShutdown(err)`** — error out every waiter you still hold (active swap waiters and queued requests). Don't touch processes; teardown is `baseRouter`'s job.
|
||||
|
||||
Expose a constructor matching the `Factory` shape:
|
||||
|
||||
```go
|
||||
func NewMyScheduler(name string, logger *logmon.Monitor, swapper Swapper, eff Effects) *MyScheduler {
|
||||
// ...
|
||||
}
|
||||
|
||||
// in the concrete router:
|
||||
base := newBaseRouter(name, conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewMyScheduler(name, logger, swapper, eff)
|
||||
})
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
- **Schedulers** are tested as pure state machines in the `scheduler` package:
|
||||
drive the `On*` methods directly against a `fakeEffects` and assert on the
|
||||
recorded grants/starts/stops. No goroutines, no sleeps. See
|
||||
`scheduler/fifo_test.go` as the reference; follow the `TestSchedulerName_<scenario>`
|
||||
naming convention.
|
||||
- **`baseRouter` mechanism** (run loop, `grant`/`ServeHTTP`, `Unload`,
|
||||
`Shutdown`) is tested in `base_test.go`. The run loop exposes a
|
||||
`testProcessed` channel so tests can wait for an event to be fully processed
|
||||
instead of sleeping.
|
||||
- Run new tests with `go test -v -run TestMyScheduler_... ./internal/router/scheduler/`,
|
||||
then `make test-dev` for a quick `go test` + `staticcheck` pass over `proxy/`.
|
||||
@@ -0,0 +1,106 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
*baseRouter
|
||||
}
|
||||
|
||||
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
if existing, dup := modelToGroup[mid]; dup {
|
||||
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
||||
}
|
||||
modelToGroup[mid] = gid
|
||||
}
|
||||
}
|
||||
|
||||
swapper := &groupSwapper{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
}
|
||||
|
||||
processes := make(map[string]process.Process, len(modelToGroup))
|
||||
base, err := newBaseRouter("group", conf, processes, proxylog, swapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating base router: %w", err)
|
||||
}
|
||||
|
||||
for mid := range modelToGroup {
|
||||
modelCfg, _, ok := conf.FindConfig(mid)
|
||||
if !ok {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("no model config for %q", mid)
|
||||
}
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
}
|
||||
|
||||
g := &Group{baseRouter: base}
|
||||
go base.run()
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// groupSwapper decides evictions from static group configuration.
|
||||
//
|
||||
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
||||
// members are stopped only when the target's group is exclusive; loading a
|
||||
// model from a non-exclusive group leaves running exclusive groups alone,
|
||||
// matching the gotcha in the original ProcessGroup behaviour.
|
||||
type groupSwapper struct {
|
||||
config config.Config
|
||||
modelToGroup map[string]string
|
||||
}
|
||||
|
||||
func (p *groupSwapper) EvictionFor(target string, running []string) []string {
|
||||
tg := p.modelToGroup[target]
|
||||
tgCfg := p.config.Routing.Router.Settings.Groups[tg]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
var result []string
|
||||
consider := func(mID string) {
|
||||
if mID == target {
|
||||
return
|
||||
}
|
||||
if _, dup := seen[mID]; dup {
|
||||
return
|
||||
}
|
||||
og := p.modelToGroup[mID]
|
||||
switch {
|
||||
case og == tg && tgCfg.Swap:
|
||||
seen[mID] = struct{}{}
|
||||
result = append(result, mID)
|
||||
// the previous ProcessGroup behaviour did not unload exclusive groups
|
||||
// when loading a non-exclusive model. This maintains that gotcha
|
||||
// for backwards compatibility. The newer swap matrix approach does not
|
||||
// have this issue.
|
||||
case og != tg && tgCfg.Exclusive:
|
||||
if ogCfg := p.config.Routing.Router.Settings.Groups[og]; !ogCfg.Persistent {
|
||||
seen[mID] = struct{}{}
|
||||
result = append(result, mID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, mID := range running {
|
||||
consider(mID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *groupSwapper) OnSwapStart(target string, running []string) {}
|
||||
@@ -0,0 +1,335 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// newTestGroup builds a Group directly from the supplied processes and config,
|
||||
// bypassing NewGroup's call to process.New.
|
||||
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
||||
t.Helper()
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
modelToGroup[mid] = gid
|
||||
}
|
||||
}
|
||||
swapper := &groupSwapper{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
}
|
||||
base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper)
|
||||
if err != nil {
|
||||
t.Fatalf("newBaseRouter: %v", err)
|
||||
}
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
g := &Group{baseRouter: base}
|
||||
go base.run()
|
||||
t.Cleanup(func() {
|
||||
if !g.shuttingDown.Load() {
|
||||
_ = g.Shutdown(time.Second)
|
||||
}
|
||||
})
|
||||
return g
|
||||
}
|
||||
|
||||
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
||||
conf := config.Config{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Members: []string{"a"}},
|
||||
}),
|
||||
Models: map[string]config.ModelConfig{
|
||||
"a": {},
|
||||
},
|
||||
}
|
||||
log := logmon.NewWriter(io.Discard)
|
||||
if _, err := NewGroup(conf, log, log); err == nil {
|
||||
t.Fatalf("expected error for duplicate membership")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0) // park a Run goroutine so Stop has something to release
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := b.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("b.serveCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_CrossGroupExclusive(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for
|
||||
// models in distinct non-exclusive groups load in parallel rather than
|
||||
// serializing through the router's run loop.
|
||||
func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
// Both groups load concurrently — both must reach Run() before either is
|
||||
// marked ready. If the router still serialised, only one would proceed.
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
||||
// (Swap=true) serialise even when both arrive while neither has reached
|
||||
// StateStarting yet — the in-flight swap target the scheduler folds into the
|
||||
// running set closes that race.
|
||||
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
// Request B arrives before A transitions to StateStarting in the process
|
||||
// state machine. Without folding the in-flight swap target into the running
|
||||
// set, the swapper would not see A as running, and B would start in
|
||||
// parallel, violating Swap=true.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started in parallel: runCalls=%d want 0", got)
|
||||
}
|
||||
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_PersistentNotEvicted verifies that a group with persistent=true
|
||||
// is never evicted when another exclusive group starts loading. The running
|
||||
// model in the persistent group stays alive alongside the new one.
|
||||
func TestGroup_PersistentNotEvicted(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
||||
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got)
|
||||
}
|
||||
if a.State() != process.StateStarting && a.State() != process.StateReady {
|
||||
t.Errorf("a state=%s want still running", a.State())
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible
|
||||
// gotcha from the original ProcessGroup: when a model in a non-exclusive group
|
||||
// is loaded, any running exclusive group keeps running. The two coexist.
|
||||
func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got)
|
||||
}
|
||||
if a.State() != process.StateStarting && a.State() != process.StateReady {
|
||||
t.Errorf("a state=%s want still running", a.State())
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// groupRouting builds a normalized RoutingConfig for the group router, mirroring
|
||||
// what config.LoadConfigFromReader produces. Tests use it to populate
|
||||
// config.Config.Routing without going through LoadConfig.
|
||||
func groupRouting(groups map[string]config.GroupConfig) config.RoutingConfig {
|
||||
return config.RoutingConfig{
|
||||
Router: config.RouterConfig{
|
||||
Use: "group",
|
||||
Settings: config.RouterSettings{Groups: groups},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// fakeProcess is an in-memory implementation of process.Process used to drive
|
||||
// the routers through their state machine without spawning real upstreams.
|
||||
type fakeProcess struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
state process.ProcessState
|
||||
readyCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
runStarted chan struct{} // closed on the first Run call
|
||||
stopStarted chan struct{} // closed on the first Stop call
|
||||
|
||||
autoReady bool
|
||||
|
||||
// serveBlock, when non-nil, makes ServeHTTP receive from it before
|
||||
// writing its response. Tests use this to hold a request in-flight.
|
||||
// Closing the channel releases every blocked ServeHTTP caller.
|
||||
serveBlock chan struct{}
|
||||
// serveStarted is closed on the first ServeHTTP entry, letting tests
|
||||
// wait deterministically for the handler to begin executing.
|
||||
serveStarted chan struct{}
|
||||
// stopBlock, when non-nil, makes Stop receive from it (after signalling
|
||||
// stopStarted) before completing. Tests use this to prove that several
|
||||
// Stop calls can be in flight simultaneously.
|
||||
stopBlock chan struct{}
|
||||
|
||||
runCalls atomic.Int32
|
||||
stopCalls atomic.Int32
|
||||
serveCalls atomic.Int32
|
||||
|
||||
// inFlightServe counts ServeHTTP calls currently inside the handler.
|
||||
// stoppedWhileServing flips true if Stop is ever called while that
|
||||
// counter is non-zero — a direct, race-free observation of the
|
||||
// "swap mid-request" anti-property.
|
||||
inFlightServe atomic.Int32
|
||||
stoppedWhileServing atomic.Bool
|
||||
}
|
||||
|
||||
func newFakeProcess(id string) *fakeProcess {
|
||||
return &fakeProcess{
|
||||
id: id,
|
||||
state: process.StateStopped,
|
||||
readyCh: make(chan struct{}),
|
||||
stopCh: make(chan struct{}),
|
||||
runStarted: make(chan struct{}),
|
||||
stopStarted: make(chan struct{}),
|
||||
serveStarted: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProcess) setState(s process.ProcessState) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.state = s
|
||||
if s == process.StateReady {
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
default:
|
||||
close(f.readyCh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProcess) State() process.ProcessState {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.state
|
||||
}
|
||||
|
||||
func (f *fakeProcess) markReady() { f.setState(process.StateReady) }
|
||||
|
||||
func (f *fakeProcess) Run(_ time.Duration) error {
|
||||
f.runCalls.Add(1)
|
||||
f.mu.Lock()
|
||||
if f.state != process.StateStopped {
|
||||
s := f.state
|
||||
f.mu.Unlock()
|
||||
return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s)
|
||||
}
|
||||
f.state = process.StateStarting
|
||||
sc := f.stopCh
|
||||
select {
|
||||
case <-f.runStarted:
|
||||
default:
|
||||
close(f.runStarted)
|
||||
}
|
||||
f.mu.Unlock()
|
||||
|
||||
if f.autoReady {
|
||||
f.setState(process.StateReady)
|
||||
}
|
||||
<-sc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeProcess) Stop(_ time.Duration) error {
|
||||
f.stopCalls.Add(1)
|
||||
if f.inFlightServe.Load() > 0 {
|
||||
f.stoppedWhileServing.Store(true)
|
||||
}
|
||||
f.mu.Lock()
|
||||
select {
|
||||
case <-f.stopStarted:
|
||||
default:
|
||||
close(f.stopStarted)
|
||||
}
|
||||
f.mu.Unlock()
|
||||
|
||||
// Test hook: hold Stop here so the test can prove multiple Stops are
|
||||
// in flight at the same time before any of them complete.
|
||||
if f.stopBlock != nil {
|
||||
<-f.stopBlock
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.state == process.StateStopped {
|
||||
return nil
|
||||
}
|
||||
f.state = process.StateStopped
|
||||
select {
|
||||
case <-f.stopCh:
|
||||
default:
|
||||
close(f.stopCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeProcess) WaitReady(ctx context.Context) error {
|
||||
f.mu.Lock()
|
||||
if f.state == process.StateReady {
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
rc := f.readyCh
|
||||
f.mu.Unlock()
|
||||
select {
|
||||
case <-rc:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProcess) Logger() *logmon.Monitor { return logmon.NewWriter(io.Discard) }
|
||||
|
||||
func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
|
||||
f.serveCalls.Add(1)
|
||||
f.inFlightServe.Add(1)
|
||||
defer f.inFlightServe.Add(-1)
|
||||
f.mu.Lock()
|
||||
select {
|
||||
case <-f.serveStarted:
|
||||
default:
|
||||
close(f.serveStarted)
|
||||
}
|
||||
f.mu.Unlock()
|
||||
if f.serveBlock != nil {
|
||||
<-f.serveBlock
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "ok:%s", f.id)
|
||||
}
|
||||
|
||||
// waitProcessed drains n events from ch, fataling on timeout. One event fires
|
||||
// per handlerReq or swapDone fully absorbed by run().
|
||||
func waitProcessed(t *testing.T, ch chan struct{}, n int) {
|
||||
t.Helper()
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("waitProcessed: only %d/%d events received", i, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newRequest(model string) *http.Request {
|
||||
body := fmt.Sprintf(`{"model":%q}`, model)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
return r
|
||||
}
|
||||
|
||||
func newRequestCtx(ctx context.Context, model string) *http.Request {
|
||||
return newRequest(model).WithContext(ctx)
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
var loadingPaths = []string{
|
||||
"/v1/chat/completions",
|
||||
}
|
||||
|
||||
func isLoadingPath(path string) bool {
|
||||
for _, p := range loadingPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type loadingWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
req *http.Request
|
||||
ctx context.Context
|
||||
logger *logmon.Monitor
|
||||
modelName string
|
||||
startTime time.Time
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pendingUpdate string
|
||||
|
||||
// writeMu serializes writes to the underlying writer and guards released.
|
||||
// Once released is set, the streaming goroutine must not touch the writer
|
||||
// again — ServeHTTP has reclaimed it (to run the real handler or to return)
|
||||
// and writing/flushing a finalized response panics.
|
||||
writeMu sync.Mutex
|
||||
released bool
|
||||
|
||||
// closed by start when the goroutine finishes (after cleanup messages)
|
||||
done chan struct{}
|
||||
|
||||
// test-only: closed when start enters its loop
|
||||
loopStarted chan struct{}
|
||||
// test-only: override the 1s tick interval
|
||||
tickDuration time.Duration
|
||||
// test-only: override character streaming speed (0 = no delay)
|
||||
charPerSecond float64
|
||||
}
|
||||
|
||||
func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter {
|
||||
s := &loadingWriter{
|
||||
writer: w,
|
||||
req: req,
|
||||
ctx: req.Context(),
|
||||
logger: logger,
|
||||
modelName: modelName,
|
||||
startTime: time.Now(),
|
||||
tickDuration: 750 * time.Millisecond,
|
||||
charPerSecond: 75,
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream")
|
||||
s.Header().Set("Cache-Control", "no-cache")
|
||||
s.Header().Set("Connection", "keep-alive")
|
||||
s.WriteHeader(http.StatusOK)
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName))
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *loadingWriter) setUpdate(msg string) {
|
||||
s.pendingMu.Lock()
|
||||
s.pendingUpdate = msg
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) start(ctx context.Context) {
|
||||
s.done = make(chan struct{})
|
||||
defer close(s.done)
|
||||
|
||||
defer func() {
|
||||
// Skip cleanup writes if the client disconnected — the connection
|
||||
// is being torn down and flushing against it will panic.
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
duration := time.Since(s.startTime)
|
||||
s.sendData("\n")
|
||||
s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Time{}
|
||||
|
||||
ticker := time.NewTicker(s.tickDuration)
|
||||
defer ticker.Stop()
|
||||
|
||||
if s.loopStarted != nil {
|
||||
close(s.loopStarted)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.pendingMu.Lock()
|
||||
update := s.pendingUpdate
|
||||
s.pendingUpdate = ""
|
||||
s.pendingMu.Unlock()
|
||||
|
||||
if update != "" {
|
||||
s.sendData("\n")
|
||||
s.sendInline(update)
|
||||
s.sendData(" ")
|
||||
lastRemarkTime = time.Now()
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendData("\n")
|
||||
s.sendInline(remark)
|
||||
s.sendData(" ")
|
||||
lastRemarkTime = time.Now()
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
if s.done == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loadingWriter) sendInline(text string) {
|
||||
chunkSize := 10
|
||||
if s.charPerSecond > 0 {
|
||||
chunkSize = max(3, int(s.charPerSecond)/15)
|
||||
}
|
||||
|
||||
runes := []rune(text)
|
||||
for i := 0; i < len(runes); {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
end := i + chunkSize
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
chunk := string(runes[i:end])
|
||||
s.sendData(chunk)
|
||||
i = end
|
||||
|
||||
if i < len(runes) && s.charPerSecond > 0 {
|
||||
time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loadingWriter) sendLine(line string) {
|
||||
if line == "" {
|
||||
s.sendData("\n")
|
||||
return
|
||||
}
|
||||
s.sendInline(line)
|
||||
s.sendData("\n")
|
||||
}
|
||||
|
||||
func (s *loadingWriter) sendData(data string) {
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
// Once ServeHTTP has reclaimed the writer (release), writing/flushing it
|
||||
// races the real handler or panics on a finalized response. Stop here.
|
||||
if s.released {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData); err != nil {
|
||||
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
|
||||
return
|
||||
}
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// release fences the loadingWriter off from the underlying ResponseWriter.
|
||||
// After it returns, the streaming goroutine will not write to or flush the
|
||||
// writer again: any in-flight write completes under writeMu first, and later
|
||||
// writes short-circuit on released. The caller can then safely hand the writer
|
||||
// to the real handler or let ServeHTTP return without racing a finalized
|
||||
// response (a use-after-return Flush panics on the recycled *bufio.Writer).
|
||||
func (s *loadingWriter) release() {
|
||||
s.writeMu.Lock()
|
||||
s.released = true
|
||||
s.writeMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *loadingWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package router
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting",
|
||||
"Reticulating splines",
|
||||
"Waking up the hamsters",
|
||||
"Teaching the model manners",
|
||||
"Convincing the GPU to participate",
|
||||
"Loading weights (they're heavy)",
|
||||
"Please enjoy this elevator music in your head",
|
||||
"Pretending to be productive",
|
||||
"Reading the entire internet, page by page",
|
||||
"Staring at the abyss, the abyss is buffering",
|
||||
"Applying layer after layer of disembodied cognition",
|
||||
"Remembering everything it forgot during quantization",
|
||||
"Counting to 405 billion, one parameter at a time",
|
||||
"Summoning the stochastic parroting",
|
||||
"Hold on, the GPU is questioning its existence",
|
||||
"Deciding which facts to hallucinate today",
|
||||
"Untangling the transformer spaghetti",
|
||||
"Warming up the token soup",
|
||||
"Your prompt is in a queue, behind 7 billion other thoughts",
|
||||
"Running `sudo apt-get install intelligence`",
|
||||
"Defragmenting the latent space",
|
||||
"Polishing each matrix multiplication by hand",
|
||||
"Whispering sweet nothings to the attention heads",
|
||||
"Aligning with human values, one reluctant epoch at a time",
|
||||
"The model is thinking about what it's about to think about",
|
||||
"Loading... and by loading we mean making you wait",
|
||||
"Spinning up the cloud GPU, please be patient while we burn your credits",
|
||||
"Applying duct tape to the context window",
|
||||
"Bribing the GPU scheduler for a timeslice",
|
||||
"Would you like to hear a fun fact while we load? Too bad.",
|
||||
"Hot swapping your sanity for an LLM",
|
||||
"Compressing optimism into FP16",
|
||||
"Ignoring 90% of the attention to save you 50% of the time",
|
||||
"Counting the exact same thing three times just to be sure",
|
||||
"Sorry, the inference you have reached is not in service",
|
||||
"Rotating the positional encodings counterclockwise for good luck",
|
||||
"Your call is very important to us. Please continue to hold.",
|
||||
"Unpacking the blobs. All 300GB of them.",
|
||||
"Initializing the thing that initializes the other thing",
|
||||
"Converting electricity into existential dread",
|
||||
"Flattening the curve... wait, the tensor. Flattening the tensor.",
|
||||
"Fetching the fetch of a fetch, callback hell edition",
|
||||
"The GPU is at 100%. The fan is now a helicopter.",
|
||||
"Baking the weights at 350° for a golden-brown inference",
|
||||
"Recalibrating the confidence of things it's still wrong about",
|
||||
"Have you tried turning it off and on again? No? Good, wait here.",
|
||||
"Simulating deep thought by pausing dramatically",
|
||||
"Loading the model that knows more than you but still can't count r's in 'strawberry'",
|
||||
"Convincing CUDA to cooperate. This may take a while.",
|
||||
"VRAM: 23.9GB used of 24GB. Living on the edge.",
|
||||
"Processing your request with the urgency of a DMV employee",
|
||||
"This model was trained on the entire internet, including that embarrassing blog you wrote in 2008",
|
||||
"Dispatching tokens through a series of increasingly confused matrix multiplies",
|
||||
"Gently lowering your expectations",
|
||||
"Applying softmax to our feelings about this load time",
|
||||
"Autoregressively generating disappointment, one token at a time",
|
||||
"The magic is happening. Somewhere. Probably.",
|
||||
"Synchronizing the parallel processes that run in parallel but really don't",
|
||||
"Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.",
|
||||
"Loading... just like it said 30 seconds ago. And will say 30 seconds from now.",
|
||||
"Pre-warming the cache so the first query is only slightly slower than the rest",
|
||||
"Have you considered that maybe your question wasn't worth all this compute?",
|
||||
"Downloading more RAM (no, really, we're mmap-ing the weights)",
|
||||
"Translating your prompt into math it barely understands",
|
||||
"Estimating your time remaining with 0% accuracy",
|
||||
"Buffering enthusiasm",
|
||||
"Model is loading. Go make some coffee. Or a three-course meal.",
|
||||
"Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'",
|
||||
"Polling for readiness in a loop that would make your CS professor weep",
|
||||
"Performing percussive maintenance on the attention mechanism",
|
||||
"This loading screen is singlehandedly reversing climate progress",
|
||||
"Decompressing the hopes and dreams of thousands of underpaid labelers",
|
||||
"Filling the key-value cache with the ghost of prompts past",
|
||||
"Currently at step 3 of 9,742 of loading. We'll get there. Eventually.",
|
||||
"If you stare at the spinner, it spins slower. It's science.",
|
||||
"Multiplying matricies with the enthusiasm of a teenager doing chores",
|
||||
"Applying `torch.nap()` until the model feels refreshed",
|
||||
"Reacquainting the model with the concept of 'facts' it forgot during fine-tuning",
|
||||
"Sorry for the wait. No, wait, we're not actually sorry.",
|
||||
"Your GPU is now a space heater with a side hustle in linear algebra",
|
||||
"Allocating memory like a billionaire allocates tax avoidance strategies",
|
||||
"The model saw \"As an AI language model\" and won't stop saying it now",
|
||||
"Installing dependencies you didn't know existed and will never use again",
|
||||
"Re-reading 'Attention Is All You Need' for the 400th time",
|
||||
"Convincing the embedding layer that context is overrated",
|
||||
"Manually untangling the residual connections with a tiny comb",
|
||||
"On hold with the cloud provider trying to explain why 8 H100s isn't enough",
|
||||
"Adjusting temperatures: model is 0.7, server room is 104°F",
|
||||
"Please hold while we justify this electricity bill to accounting",
|
||||
"Stacking decoder blocks like a Jenga tower at a LAN party",
|
||||
"Compensating for your lack of patience with our lack of speed",
|
||||
"This is a loading screen comment. Loading screens have comments now. Welcome to the future.",
|
||||
"Processing the entire works of Shakespeare backwards just in case",
|
||||
"The model is loading slower than your last `npm install`",
|
||||
"Rehearsing plausible-sounding explanations for why it got everything wrong",
|
||||
"Populating the context with filler while you wait for actual content",
|
||||
"Optimizing for BLEU score, which definitely correlates with making you laugh",
|
||||
"Generating an embedding for each and every letter of the alphabet, individually",
|
||||
"Coming soon: llama-swap v2 with actual performance improvements. Probably.",
|
||||
"Loading a model larger than your attention span",
|
||||
"Performing a seance to invoke the spirit of Geoff Hinton",
|
||||
"Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.",
|
||||
"Converting all the internet's bad opinions into a surprisingly useful autocomplete",
|
||||
"Laying down each layer with the care of a Michelin-starred pastry chef",
|
||||
"Checking if the model still thinks birds are government drones. Yep.",
|
||||
"Activating the neurons responsible for 'I cannot assist with that request'",
|
||||
"This model was trained on the same internet that brought you Rickrolling. You're welcome.",
|
||||
"Realigning the alignment so it aligns with the previous alignment",
|
||||
"Running `nvidia-smi` and sighing heavily",
|
||||
"If you close your eyes, the loading bar moves faster. Proven by science.",
|
||||
"EULA said 'by using this software you agree to wait forever' and you clicked Accept",
|
||||
"Zipping the GPUs to make them go faster",
|
||||
"Padding the context window with existential padding",
|
||||
"We could have used a smaller model but someone wanted 'quality'",
|
||||
"Disentangling the latent space into something resembling coherence",
|
||||
"Slow is smooth, smooth is fast, but this is just slow",
|
||||
"Memory-mapping like it's a AAA title from 2012",
|
||||
"Your patience has been tokenized and added to the training set. Thank you for your contribution.",
|
||||
"Loading is CPU-bound and your CPU is busy regretting its life choices",
|
||||
"Exploring the high-dimensional manifold of ways to say 'just a moment'",
|
||||
"The model is experiencing a brief but intense moment of imposter syndrome",
|
||||
"Initializing 7B parameters by rolling 7B 16-sided dice",
|
||||
"Panic! at the disk I/O",
|
||||
"Intelligence is loading... your definition of intelligence may vary",
|
||||
"This model was distilled. Unlike your patience, which is evaporating.",
|
||||
"Unzipping the model. It's a .gguf file, not a metaphor.",
|
||||
"Running inference on the concept of 'soon' to estimate remaining time",
|
||||
"Loading with all the speed of a government-funded IT project",
|
||||
"A blank terminal is a terrible thing to waste. Here's a loading message instead.",
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
|
||||
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
|
||||
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
|
||||
}
|
||||
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
|
||||
t.Errorf("Cache-Control: want no-cache, got %q", cc)
|
||||
}
|
||||
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
|
||||
t.Errorf("Connection: want keep-alive, got %q", conn)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.HasPrefix(body, "data: ") {
|
||||
t.Errorf("expected SSE data: prefix, got: %s", body)
|
||||
}
|
||||
|
||||
content := extractStreamedContent(body)
|
||||
if !strings.Contains(content, "━━━━━\n") {
|
||||
t.Errorf("missing separator in streamed content: %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
|
||||
t.Errorf("missing initial message in streamed content: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.WriteHeader(http.StatusCreated)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_WritePassthrough(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.Write([]byte("hello"))
|
||||
lw.Flush()
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "hello") {
|
||||
t.Errorf("Write passthrough failed, body: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.tickDuration = 10 * time.Millisecond
|
||||
lw.loopStarted = make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go lw.start(ctx)
|
||||
<-lw.loopStarted
|
||||
cancel()
|
||||
|
||||
if !lw.waitForCompletion(time.Second) {
|
||||
t.Fatal("waitForCompletion timed out")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "Done!") {
|
||||
t.Errorf("expected Done! message, body: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.tickDuration = 10 * time.Millisecond
|
||||
lw.charPerSecond = 0
|
||||
lw.loopStarted = make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go lw.start(ctx)
|
||||
<-lw.loopStarted
|
||||
|
||||
lw.setUpdate("custom status message")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
if !lw.waitForCompletion(time.Second) {
|
||||
t.Fatal("waitForCompletion timed out")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
content := extractStreamedContent(body)
|
||||
if !strings.Contains(content, "custom status message") {
|
||||
t.Errorf("expected setUpdate message in output, got: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_SendDataFormat(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.sendData("hello world")
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
|
||||
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
|
||||
}
|
||||
if !strings.HasPrefix(body, "data: ") {
|
||||
t.Errorf("expected data: prefix, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_SendLine(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.charPerSecond = 0
|
||||
|
||||
// Capture only the content from this sendLine call
|
||||
before := w.Body.Len()
|
||||
lw.sendLine("line content")
|
||||
after := w.Body.Len()
|
||||
chunkBody := w.Body.String()[before:after]
|
||||
|
||||
content := extractStreamedContent(chunkBody)
|
||||
if content != "line content\n" {
|
||||
t.Errorf("expected complete streamed line, got: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.tickDuration = 10 * time.Millisecond
|
||||
lw.charPerSecond = 0
|
||||
lw.loopStarted = make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
lw.start(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-lw.loopStarted
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
body := w.Body.String()
|
||||
lines := countSSEMessages(body)
|
||||
if lines < 2 {
|
||||
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_ReqStored(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
if lw.req != req {
|
||||
t.Fatal("req not stored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLoadingPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
{"/v1/chat/completions", true},
|
||||
{"/v1/chat/completions/extra", true},
|
||||
{"/v1/completions", false},
|
||||
{"/v1/embeddings", false},
|
||||
{"/health", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
if got := isLoadingPath(tt.path); got != tt.want {
|
||||
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func countSSEMessages(s string) int {
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
count := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func extractStreamedContent(body string) string {
|
||||
var result strings.Builder
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
jsonData := strings.TrimPrefix(line, "data: ")
|
||||
var msg struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
if len(msg.Choices) > 0 {
|
||||
result.WriteString(msg.Choices[0].Delta.ReasoningContent)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
type Matrix struct {
|
||||
*baseRouter
|
||||
}
|
||||
|
||||
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
||||
mtx := conf.Routing.Router.Settings.Matrix
|
||||
if mtx == nil {
|
||||
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
||||
}
|
||||
|
||||
swapper := &matrixSwapper{
|
||||
solver: newMatrixSolver(mtx.ExpandedSets, mtx.ResolvedEvictCosts()),
|
||||
logger: proxylog,
|
||||
}
|
||||
|
||||
// Build a process for every model in the config. Any model can run alone
|
||||
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
||||
processes := make(map[string]process.Process, len(conf.Models))
|
||||
base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating base router: %w", err)
|
||||
}
|
||||
|
||||
for mid, modelCfg := range conf.Models {
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
}
|
||||
|
||||
r := &Matrix{baseRouter: base}
|
||||
go base.run()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// matrixSwapper decides evictions by asking the matrix solver against the
|
||||
// running set the scheduler hands it.
|
||||
type matrixSwapper struct {
|
||||
solver *matrixSolver
|
||||
logger *logmon.Monitor
|
||||
}
|
||||
|
||||
func (p *matrixSwapper) EvictionFor(target string, running []string) []string {
|
||||
return p.solver.Solve(target, running).Evict
|
||||
}
|
||||
|
||||
func (p *matrixSwapper) OnSwapStart(target string, running []string) {
|
||||
result := p.solver.Solve(target, running)
|
||||
switch {
|
||||
case len(result.Evict) > 0:
|
||||
p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||
target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||
case len(running) == 0:
|
||||
p.logger.Infof("matrix: model=%s starting (no models running)", target)
|
||||
default:
|
||||
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
// matrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||
// It is safe for concurrent reads after construction.
|
||||
type matrixSolver struct {
|
||||
expandedSets []config.ExpandedSet // all valid model combinations
|
||||
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||
}
|
||||
|
||||
func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver {
|
||||
modelToSets := make(map[string][]int)
|
||||
for i, es := range expandedSets {
|
||||
for _, model := range es.Models {
|
||||
modelToSets[model] = append(modelToSets[model], i)
|
||||
}
|
||||
}
|
||||
|
||||
return &matrixSolver{
|
||||
expandedSets: expandedSets,
|
||||
evictCosts: evictCosts,
|
||||
modelToSets: modelToSets,
|
||||
}
|
||||
}
|
||||
|
||||
// solveResult describes what the solver decided.
|
||||
type solveResult struct {
|
||||
Evict []string // running models that must be stopped
|
||||
TargetSet []string // the chosen set of models (for informational purposes)
|
||||
SetName string // name of the chosen set
|
||||
DSL string // original DSL expression for the chosen set
|
||||
TotalCost int // total eviction cost
|
||||
}
|
||||
|
||||
// Solve determines which models to evict when a model is requested.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. If requestedModel is already running, no eviction needed.
|
||||
// 2. Find all sets containing requestedModel.
|
||||
// 3. If no sets found, the model runs alone; evict all running models.
|
||||
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||
// models NOT in that set.
|
||||
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||
// 6. Return models to evict and the chosen set.
|
||||
func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult {
|
||||
if slices.Contains(runningModels, requestedModel) {
|
||||
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||
return solveResult{
|
||||
TargetSet: runningModels,
|
||||
SetName: setName,
|
||||
DSL: dsl,
|
||||
}
|
||||
}
|
||||
|
||||
candidateIndices := s.modelToSets[requestedModel]
|
||||
|
||||
// Model not in any set: runs alone, evict everything.
|
||||
if len(candidateIndices) == 0 {
|
||||
evict := make([]string, len(runningModels))
|
||||
copy(evict, runningModels)
|
||||
return solveResult{
|
||||
Evict: evict,
|
||||
TargetSet: []string{requestedModel},
|
||||
}
|
||||
}
|
||||
|
||||
bestCost := -1
|
||||
bestIdx := -1
|
||||
|
||||
for _, idx := range candidateIndices {
|
||||
setModels := s.expandedSets[idx].Models
|
||||
cost := 0
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(setModels, running) {
|
||||
cost += s.evictCost(running)
|
||||
}
|
||||
}
|
||||
|
||||
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||
bestCost = cost
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
chosen := s.expandedSets[bestIdx]
|
||||
var evict []string
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(chosen.Models, running) {
|
||||
evict = append(evict, running)
|
||||
}
|
||||
}
|
||||
|
||||
return solveResult{
|
||||
Evict: evict,
|
||||
TargetSet: chosen.Models,
|
||||
SetName: chosen.SetName,
|
||||
DSL: chosen.DSL,
|
||||
TotalCost: bestCost,
|
||||
}
|
||||
}
|
||||
|
||||
// findMatchingSet finds the expanded set that contains all running models.
|
||||
// Returns the set name and DSL, or empty strings if no match.
|
||||
func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||
for _, idx := range s.modelToSets[requestedModel] {
|
||||
set := s.expandedSets[idx]
|
||||
allInSet := true
|
||||
for _, m := range runningModels {
|
||||
if !slices.Contains(set.Models, m) {
|
||||
allInSet = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSet {
|
||||
return set.SetName, set.DSL
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (s *matrixSolver) evictCost(model string) int {
|
||||
if cost, ok := s.evictCosts[model]; ok {
|
||||
return cost
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// newTestMatrix builds a Matrix router from supplied processes, bypassing
|
||||
// NewMatrix's call to process.New.
|
||||
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
swapper := &matrixSwapper{
|
||||
solver: newMatrixSolver(expanded, evictCosts),
|
||||
logger: logger,
|
||||
}
|
||||
base, err := newBaseRouter("matrix", conf, processes, logger, swapper)
|
||||
if err != nil {
|
||||
t.Fatalf("newBaseRouter: %v", err)
|
||||
}
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
r := &Matrix{baseRouter: base}
|
||||
go base.run()
|
||||
t.Cleanup(func() {
|
||||
if !r.shuttingDown.Load() {
|
||||
_ = r.Shutdown(time.Second)
|
||||
}
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func baseMatrixConfig() config.Config {
|
||||
return config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers
|
||||
// eviction of running models that are not in any shared set with it.
|
||||
func TestMatrix_SwapEvictsConflicting(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0) // park a Run goroutine so Stop has something to release
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
// Two single-model sets: a and b never coexist, so loading b must evict a.
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
|
||||
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_CoexistInSet verifies that a model is not evicted when the target
|
||||
// shares a set with it (the fast path applies if the target is already ready).
|
||||
func TestMatrix_CoexistInSet(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
// Both fit in s_ab, so b's swap should not stop a.
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_CoexistingSetParallel verifies that two models that share an
|
||||
// expanded set load in parallel — the solver returns empty Evict for both,
|
||||
// the collision predicate clears them, and both swaps run together.
|
||||
func TestMatrix_CoexistingSetParallel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
||||
// that cannot coexist with the in-flight first model queues until the first
|
||||
// completes, and then evicts it. This exercises the scheduler folding in-flight
|
||||
// swap targets into the running set it hands the swapper.
|
||||
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
|
||||
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
// B arrives before A transitions to StateStarting. The running set the
|
||||
// scheduler builds includes A (an in-flight swap target), so the solver
|
||||
// returns evict=[a] and collidesWith forces B to queue.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started in parallel: runCalls=%d want 0", got)
|
||||
}
|
||||
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrixSolver_TieBreakDefinitionOrder pins the solver's tie-break rule:
|
||||
// when multiple candidate sets have equal eviction cost, the earlier-defined
|
||||
// set wins.
|
||||
func TestMatrixSolver_TieBreakDefinitionOrder(t *testing.T) {
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "first", DSL: "a & b", Models: []string{"a", "b"}},
|
||||
{SetName: "second", DSL: "a & c", Models: []string{"a", "c"}},
|
||||
}
|
||||
s := newMatrixSolver(expanded, nil)
|
||||
|
||||
// No models running, request "a": both sets have cost 0 and contain a.
|
||||
// Definition order: "first" wins.
|
||||
result := s.Solve("a", nil)
|
||||
if result.SetName != "first" {
|
||||
t.Errorf("SetName=%q want %q", result.SetName, "first")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrixSolver_EvictCostsPreferred verifies that higher evict costs steer
|
||||
// the solver toward a cheaper set.
|
||||
func TestMatrixSolver_EvictCostsPreferred(t *testing.T) {
|
||||
// b is expensive to evict; c is cheap. Request "a" with both b and c
|
||||
// running. The solver should pick the set that keeps b.
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "a_with_c", DSL: "a & c", Models: []string{"a", "c"}}, // would evict b (cost 10)
|
||||
{SetName: "a_with_b", DSL: "a & b", Models: []string{"a", "b"}}, // would evict c (cost 1)
|
||||
}
|
||||
s := newMatrixSolver(expanded, map[string]int{"b": 10, "c": 1})
|
||||
|
||||
result := s.Solve("a", []string{"b", "c"})
|
||||
if result.SetName != "a_with_b" {
|
||||
t.Errorf("SetName=%q want %q (keep expensive b)", result.SetName, "a_with_b")
|
||||
}
|
||||
if len(result.Evict) != 1 || result.Evict[0] != "c" {
|
||||
t.Errorf("Evict=%v want [c]", result.Evict)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type peerMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
cfg config.Config
|
||||
logger *logmon.Monitor
|
||||
peers map[string]*peerMember
|
||||
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
inflight sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
|
||||
peers := cfg.Peers
|
||||
modelMap := make(map[string]*peerMember)
|
||||
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
reverseProxy := &httputil.ReverseProxy{
|
||||
Transport: peerTransport,
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
r.SetURL(peer.ProxyURL)
|
||||
r.Out.Host = r.Out.URL.Host
|
||||
},
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
logger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := modelMap[modelID]; found {
|
||||
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
modelMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
|
||||
return &Peer{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
peers: modelMap,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Peer) Handles(model string) bool {
|
||||
_, ok := r.peers[model]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *Peer) Shutdown(timeout time.Duration) error {
|
||||
if !r.shuttingDown.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("shutdown already in progress")
|
||||
}
|
||||
|
||||
if timeout == 0 {
|
||||
r.shutdownFn()
|
||||
r.inflight.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.inflight.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-time.After(timeout):
|
||||
r.shutdownFn()
|
||||
r.inflight.Wait()
|
||||
return fmt.Errorf("peer shutdown timed out after %v", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.shuttingDown.Load() {
|
||||
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||
return
|
||||
}
|
||||
r.inflight.Add(1)
|
||||
defer r.inflight.Done()
|
||||
|
||||
data, err := shared.FetchContext(req, r.cfg)
|
||||
if err != nil {
|
||||
shared.SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
pp, found := r.peers[data.ModelID]
|
||||
if !found {
|
||||
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||
shared.SendError(w, req, ErrNoPeerModelFound)
|
||||
return
|
||||
}
|
||||
|
||||
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
|
||||
|
||||
if pp.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
req.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
// Cancel the proxy request when the client disconnects or shutdown times out.
|
||||
// AfterFunc links both parent contexts to our child without a goroutine leak.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stopReq := context.AfterFunc(req.Context(), cancel)
|
||||
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
pp.reverseProxy.ServeHTTP(w, req)
|
||||
|
||||
stopShutdown()
|
||||
stopReq()
|
||||
cancel()
|
||||
}
|
||||
@@ -0,0 +1,612 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var testLogger = logmon.NewWriter(os.Stdout)
|
||||
|
||||
func init() {
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
}
|
||||
|
||||
func TestNewPeer_EmptyPeers(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if pr == nil {
|
||||
t.Fatal("expected non-nil Peer")
|
||||
}
|
||||
if len(pr.peers) != 0 {
|
||||
t.Fatalf("expected empty peers map, got %d entries", len(pr.peers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(pr.peers) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", len(pr.peers))
|
||||
}
|
||||
if _, ok := pr.peers["model-a"]; !ok {
|
||||
t.Error("expected model-a to be mapped")
|
||||
}
|
||||
if _, ok := pr.peers["model-b"]; !ok {
|
||||
t.Error("expected model-b to be mapped")
|
||||
}
|
||||
if _, ok := pr.peers["model-c"]; ok {
|
||||
t.Error("expected model-c to not be mapped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(pr.peers) != 4 {
|
||||
t.Fatalf("expected 4 entries, got %d", len(pr.peers))
|
||||
}
|
||||
for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} {
|
||||
if _, ok := pr.peers[m]; !ok {
|
||||
t.Errorf("expected %s to be mapped", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_DuplicateModel(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(pr.peers) != 1 {
|
||||
t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers))
|
||||
}
|
||||
if _, ok := pr.peers["duplicate-model"]; !ok {
|
||||
t.Error("expected duplicate-model to be mapped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_Success(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if w.Body.String() != "response from peer" {
|
||||
t.Errorf("expected 'response from peer', got %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if receivedAuthHeader != "Bearer secret-api-key" {
|
||||
t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if receivedAuthHeader != "" {
|
||||
t.Errorf("expected no auth header, got %q", receivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if !strings.HasPrefix(receivedHost, "127.0.0.1:") {
|
||||
t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("X-Accel-Buffering") != "no" {
|
||||
t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = pr.Shutdown(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "shutting down") {
|
||||
t.Errorf("expected 'shutting down' in body, got %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
released := make(chan struct{})
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-released
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
pr.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
<-started
|
||||
|
||||
shutdownDone := make(chan error, 1)
|
||||
go func() {
|
||||
shutdownDone <- pr.Shutdown(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
// Shutdown should be waiting on inflight. If it finished already something is wrong.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
select {
|
||||
case err := <-shutdownDone:
|
||||
t.Errorf("shutdown completed before inflight finished: %v", err)
|
||||
default:
|
||||
}
|
||||
|
||||
close(released)
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case err := <-shutdownDone:
|
||||
if err != nil {
|
||||
t.Errorf("shutdown errored after inflight completed: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("shutdown did not complete after inflight finished")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
released := make(chan struct{})
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-released
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
pr.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
<-started
|
||||
|
||||
err = pr.Shutdown(100 * time.Millisecond)
|
||||
if err == nil {
|
||||
t.Error("expected timeout error from shutdown")
|
||||
}
|
||||
|
||||
close(released)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPeer_ShutdownMultiple(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = pr.Shutdown(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = pr.Shutdown(0)
|
||||
if err == nil {
|
||||
t.Error("expected error on second shutdown")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already in progress") {
|
||||
t.Errorf("expected 'already in progress', got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"extracted-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"context-model"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"body-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
member, ok := pr.peers["model1"]
|
||||
if !ok {
|
||||
t.Fatal("expected model1 to be mapped")
|
||||
}
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatal("expected Transport to be *http.Transport")
|
||||
}
|
||||
|
||||
if transport.ResponseHeaderTimeout != 300*time.Second {
|
||||
t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
}
|
||||
if transport.TLSHandshakeTimeout != 15*time.Second {
|
||||
t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
}
|
||||
if transport.ExpectContinueTimeout != 2*time.Second {
|
||||
t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout)
|
||||
}
|
||||
if transport.IdleConnTimeout != 120*time.Second {
|
||||
t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout)
|
||||
}
|
||||
if !transport.ForceAttemptHTTP2 {
|
||||
t.Error("expected ForceAttemptHTTP2 to be true")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoRouterFound = shared.ErrNoRouterFound
|
||||
ErrNoPeerModelFound = shared.ErrNoPeerModelFound
|
||||
ErrNoLocalModelFound = shared.ErrNoLocalModelFound
|
||||
)
|
||||
|
||||
type Router interface {
|
||||
// Shutdown blocks until the router has shutdown returning nil
|
||||
// when the router has shutdown successfully.
|
||||
//
|
||||
// timeout controls how long to wait for inflight requests to finish. After
|
||||
// the timeout all inflight requests will be cancelled.
|
||||
Shutdown(timeout time.Duration) error
|
||||
|
||||
// ServeHTTP implements the http.Handler and requests coming in will
|
||||
// trigger any model swapping and routing logic.
|
||||
ServeHTTP(http.ResponseWriter, *http.Request)
|
||||
|
||||
// Handles reports whether this router can serve requests for the given model.
|
||||
Handles(model string) bool
|
||||
}
|
||||
|
||||
// LocalRouter is a Router backed by local processes whose state can be
|
||||
// inspected and which can be individually stopped. Peer routers, which only
|
||||
// forward to remote hosts, do not implement it.
|
||||
type LocalRouter interface {
|
||||
Router
|
||||
|
||||
// RunningModels returns the current state of every process that is not
|
||||
// stopped or shut down, keyed by model ID.
|
||||
RunningModels() map[string]process.ProcessState
|
||||
|
||||
// Unload stops the named models, or every running model when none are
|
||||
// named. It blocks until each targeted process has stopped.
|
||||
Unload(timeout time.Duration, models ...string)
|
||||
|
||||
// ProcessLogger returns the log monitor for the named model's process.
|
||||
// modelID must be a real (non-alias) config key. Returns false when the
|
||||
// model is not known to this router.
|
||||
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
||||
}
|
||||
@@ -0,0 +1,489 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||
// the model config leaves concurrencyLimit unset.
|
||||
const defaultConcurrencyLimit = 10
|
||||
|
||||
// activeSwap tracks one in-flight swap and the callers waiting on it.
|
||||
type activeSwap struct {
|
||||
modelID string
|
||||
evict []string
|
||||
waiters []HandlerReq
|
||||
}
|
||||
|
||||
// FIFO is the default scheduler. Requests are handled in a first-in, first-out order.
|
||||
// To reduce swapping requests for a model that is already running will be handled
|
||||
// immediately by the running process.
|
||||
//
|
||||
// Requests into this schedule are handled like this:
|
||||
//
|
||||
// A B C A B C --> A A B B C C
|
||||
//
|
||||
// The strategy is simple and reduces the number of swaps required.
|
||||
type FIFO struct {
|
||||
name string
|
||||
logger *logmon.Monitor
|
||||
planner Swapper
|
||||
cfg config.FifoConfig
|
||||
effects Effects
|
||||
|
||||
limits map[string]int
|
||||
active map[string]*activeSwap
|
||||
inFlight map[string]int
|
||||
queued []HandlerReq
|
||||
}
|
||||
|
||||
// NewFIFO builds a FIFO scheduler. Per-model concurrency limits are derived
|
||||
// from models: each model's ConcurrencyLimit overrides defaultConcurrencyLimit
|
||||
// when set to a value greater than zero.
|
||||
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, models map[string]config.ModelConfig, eff Effects) *FIFO {
|
||||
limits := make(map[string]int, len(models))
|
||||
for id, mc := range models {
|
||||
limit := defaultConcurrencyLimit
|
||||
if mc.ConcurrencyLimit > 0 {
|
||||
limit = mc.ConcurrencyLimit
|
||||
}
|
||||
limits[id] = limit
|
||||
}
|
||||
|
||||
return &FIFO{
|
||||
name: name,
|
||||
logger: logger,
|
||||
planner: planner,
|
||||
cfg: cfg,
|
||||
effects: eff,
|
||||
limits: limits,
|
||||
active: make(map[string]*activeSwap),
|
||||
inFlight: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest decides what to do with one incoming ServeHTTP request. It never
|
||||
// blocks indefinitely: any work that has to wait (starting a process, stopping
|
||||
// siblings, waiting for ready) is deferred to a swap goroutine and reported back
|
||||
// via OnSwapDone.
|
||||
//
|
||||
// The decision tree, in order:
|
||||
//
|
||||
// 1. Unknown model — respond with ErrModelNotFound and move on.
|
||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||
// one swap serves all callers that asked for the same model.
|
||||
// 3. Fast path — the target process is already ready, the planner sees
|
||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||
// ServeHTTP immediately.
|
||||
// 4. Would collide with an in-flight swap (we'd stop their target, or they're
|
||||
// stopping us) — park in the queue for OnSwapDone to drain.
|
||||
// 5. Would evict a process that is still handling requests — park in the
|
||||
// queue. OnServeDone will retry when the busy process drains.
|
||||
// 6. Otherwise — start a new swap. This may run in parallel with other active
|
||||
// swaps when their evict sets don't intersect.
|
||||
func (s *FIFO) OnRequest(req HandlerReq) {
|
||||
// (1) Unknown model.
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// (2) Join an in-flight swap for the same model.
|
||||
if sw, ok := s.active[req.Model]; ok {
|
||||
s.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", s.name, req.Model, len(sw.waiters)+1)
|
||||
sw.waiters = append(sw.waiters, req)
|
||||
return
|
||||
}
|
||||
|
||||
running := s.runningSet(req.Model)
|
||||
evict := s.planner.EvictionFor(req.Model, running)
|
||||
|
||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: fast-path serving model %s (already ready)", s.name, req.Model)
|
||||
s.grantHandler(req, req.Model)
|
||||
return
|
||||
}
|
||||
|
||||
// (4) Collision with an in-flight swap — queue.
|
||||
if collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", s.name, req.Model)
|
||||
s.enqueue(req)
|
||||
return
|
||||
}
|
||||
|
||||
// (5) Would evict a busy process — queue until it drains.
|
||||
if conflictsWithInFlight(evict, s.inFlight) {
|
||||
s.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", s.name, req.Model)
|
||||
s.enqueue(req)
|
||||
return
|
||||
}
|
||||
|
||||
// (6) Start a new (possibly parallel) swap.
|
||||
s.logger.Debugf("%s: starting swap for model %s, evicting %v", s.name, req.Model, evict)
|
||||
s.startSwap(req, evict, running)
|
||||
}
|
||||
|
||||
// OnCancel removes a request whose client has disconnected from the queue and
|
||||
// from every in-flight swap's waiters. If the request was the sole waiter of an
|
||||
// active swap, the swap goroutine is left to complete on its own — OnSwapDone
|
||||
// will find no waiters and simply clean up. This prevents drainQueue from ever
|
||||
// starting a model load for a caller that is no longer there.
|
||||
func (s *FIFO) OnCancel(req HandlerReq) {
|
||||
removed := false
|
||||
|
||||
// Prune from the queue.
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, q := range s.queued {
|
||||
if q.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
}
|
||||
|
||||
// Prune from any active swap's waiters.
|
||||
for _, sw := range s.active {
|
||||
filtered := sw.waiters[:0]
|
||||
for _, w := range sw.waiters {
|
||||
if w.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, w)
|
||||
}
|
||||
sw.waiters = filtered
|
||||
}
|
||||
|
||||
if removed {
|
||||
s.logger.Debugf("%s: cancelled request for model %s pruned from scheduler", s.name, req.Model)
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
}
|
||||
|
||||
// OnSwapDone fans the result out to every waiter that joined this swap, removes
|
||||
// the swap from the active map, then walks the queue once, promoting any items
|
||||
// that no longer collide with the remaining active set. FIFO order is preserved:
|
||||
// items still blocked stay in place.
|
||||
func (s *FIFO) OnSwapDone(ev SwapDone) {
|
||||
sw, ok := s.active[ev.ModelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(s.active, ev.ModelID)
|
||||
|
||||
for _, w := range sw.waiters {
|
||||
if ev.Err != nil {
|
||||
s.effects.GrantError(w, ev.Err)
|
||||
} else {
|
||||
s.grantHandler(w, ev.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
s.drainQueue()
|
||||
}
|
||||
|
||||
// OnServeDone decrements the per-model in-flight count and, when that drops to
|
||||
// zero, retries the queue: requests whose swap was deferred because they would
|
||||
// have evicted this (now-idle) process can now proceed.
|
||||
func (s *FIFO) OnServeDone(ev ServeDoneEvent) {
|
||||
s.inFlight[ev.ModelID]--
|
||||
if s.inFlight[ev.ModelID] <= 0 {
|
||||
delete(s.inFlight, ev.ModelID)
|
||||
s.drainQueue()
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnload reconciles router-owned state with the impending Stop, performs the
|
||||
// Stop (synchronously, via Effects) so callers of Unload remain blocked until
|
||||
// each targeted process has exited, then drains the queue.
|
||||
func (s *FIFO) OnUnload(targets []string, timeout time.Duration) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(targets))
|
||||
for _, id := range targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
// Release waiters of any in-flight swap whose target is being unloaded.
|
||||
// The swap goroutine itself is left to finish on its own; when its
|
||||
// SwapDone arrives, OnSwapDone will find no entry in active and drop it.
|
||||
for id := range targetSet {
|
||||
sw, ok := s.active[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, w := range sw.waiters {
|
||||
s.effects.GrantError(w, unloadErr)
|
||||
}
|
||||
delete(s.active, id)
|
||||
}
|
||||
|
||||
// Drop queued requests addressed to unloaded models. Requests for other
|
||||
// models stay queued and may benefit from drainQueue at the end.
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, w := range s.queued {
|
||||
if targetSet[w.Model] {
|
||||
s.effects.GrantError(w, unloadErr)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, w)
|
||||
}
|
||||
s.queued = kept
|
||||
}
|
||||
|
||||
// Stop the targeted processes. Done synchronously so Unload's caller can
|
||||
// rely on "after Unload returns, the process is stopped". inFlight is
|
||||
// intentionally NOT cleared here: each dying handler will fire its tracked
|
||||
// serve and reach OnServeDone in the normal way.
|
||||
s.effects.StopProcesses(timeout, targets)
|
||||
|
||||
// Removing entries from active above may have unblocked queued requests
|
||||
// that previously collided with the now-cancelled swaps.
|
||||
s.drainQueue()
|
||||
}
|
||||
|
||||
// OnShutdown grants err to every waiter still held by the scheduler.
|
||||
func (s *FIFO) OnShutdown(err error) {
|
||||
for _, sw := range s.active {
|
||||
for _, w := range sw.waiters {
|
||||
s.effects.GrantError(w, err)
|
||||
}
|
||||
}
|
||||
for _, w := range s.queued {
|
||||
s.effects.GrantError(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
// grantHandler hands the caller a tracked handler for modelID and, only if the
|
||||
// caller was still there to receive it, bumps the in-flight count. Incrementing
|
||||
// when the grant failed would strand the counter and block future evictions.
|
||||
// Requests that would exceed the model's concurrency limit are rejected with a
|
||||
// shared.NewConcurrencyLimitError (HTTP 429 with Retry-After).
|
||||
func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
|
||||
if s.inFlight[modelID] >= s.limit(modelID) {
|
||||
s.effects.GrantError(req, shared.ConcurrencyLimitError{})
|
||||
return
|
||||
}
|
||||
|
||||
if err := shared.SetReqData(req.Ctx, "fifo_priority", strconv.Itoa(s.cfg.Priority[req.Model])); err != nil {
|
||||
s.logger.Debugf("failed to set fifo_priority metadata: %v", err)
|
||||
}
|
||||
|
||||
if s.effects.GrantServe(req, modelID) {
|
||||
s.inFlight[modelID]++
|
||||
}
|
||||
}
|
||||
|
||||
// limit returns the per-model concurrency cap, defaulting to
|
||||
// defaultConcurrencyLimit when the model has no explicit entry.
|
||||
func (s *FIFO) limit(modelID string) int {
|
||||
if l, ok := s.limits[modelID]; ok {
|
||||
return l
|
||||
}
|
||||
return defaultConcurrencyLimit
|
||||
}
|
||||
|
||||
// startSwap records the swap as active and launches it via Effects. running is
|
||||
// the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against
|
||||
// the same picture it decided on.
|
||||
func (s *FIFO) startSwap(initial HandlerReq, evict, running []string) {
|
||||
s.active[initial.Model] = &activeSwap{
|
||||
modelID: initial.Model,
|
||||
evict: evict,
|
||||
waiters: []HandlerReq{initial},
|
||||
}
|
||||
s.planner.OnSwapStart(initial.Model, running)
|
||||
s.effects.StartSwap(initial.Model, evict)
|
||||
}
|
||||
|
||||
// enqueue inserts req into the queue in priority order: it goes just before the
|
||||
// first queued item whose priority is strictly lower, so higher-priority models
|
||||
// are serviced first while equal-priority requests keep their arrival (FIFO)
|
||||
// order. Priorities come from the FifoConfig; unlisted models default to 0.
|
||||
func (s *FIFO) enqueue(req HandlerReq) {
|
||||
p := s.cfg.Priority[req.Model]
|
||||
i := len(s.queued)
|
||||
for j, q := range s.queued {
|
||||
if s.cfg.Priority[q.Model] < p {
|
||||
i = j
|
||||
break
|
||||
}
|
||||
}
|
||||
s.queued = append(s.queued, HandlerReq{})
|
||||
copy(s.queued[i+1:], s.queued[i:])
|
||||
s.queued[i] = req
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
// drainQueue walks the queued requests in order, re-running the OnRequest
|
||||
// decision tree against the (now smaller) active set. Items that can now start
|
||||
// or join become satisfied; items still blocked remain queued in original order
|
||||
// so they get another chance on the next swap completion.
|
||||
func (s *FIFO) drainQueue() {
|
||||
if len(s.queued) == 0 {
|
||||
return
|
||||
}
|
||||
pending := s.queued
|
||||
var remaining []HandlerReq
|
||||
for _, req := range pending {
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
continue
|
||||
}
|
||||
if sw, ok := s.active[req.Model]; ok {
|
||||
s.logger.Debugf("%s: queued request for model %s now joining in-flight swap", s.name, req.Model)
|
||||
sw.waiters = append(sw.waiters, req)
|
||||
continue
|
||||
}
|
||||
running := s.runningSet(req.Model)
|
||||
evict := s.planner.EvictionFor(req.Model, running)
|
||||
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: queued request for model %s now served fast-path", s.name, req.Model)
|
||||
s.grantHandler(req, req.Model)
|
||||
continue
|
||||
}
|
||||
if collidesWith(req.Model, evict, s.active) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
if conflictsWithInFlight(evict, s.inFlight) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
s.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", s.name, req.Model, evict)
|
||||
s.startSwap(req, evict, running)
|
||||
}
|
||||
s.queued = remaining
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
// runningSet is the live model set handed to the Swapper: every process the
|
||||
// baseRouter reports as running, unioned with the targets of in-flight swaps
|
||||
// (excluding excludeActive, the model whose own swap is being decided — its
|
||||
// in-flight entry must not count as "already running"). The result is sorted so
|
||||
// eviction decisions derived from it are deterministic.
|
||||
func (s *FIFO) runningSet(excludeActive string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
add := func(id string) {
|
||||
if _, dup := seen[id]; dup {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
for id := range s.effects.RunningModels() {
|
||||
add(id)
|
||||
}
|
||||
for _, id := range activeTargets(s.active, excludeActive) {
|
||||
add(id)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||
// The planner uses this to account for models committed to but not yet reflected
|
||||
// in process state.
|
||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||
if len(active) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(active))
|
||||
for id := range active {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// collidesWith reports whether a new swap with this target and evict set can
|
||||
// safely run alongside the currently active swaps. Same-target callers should
|
||||
// JOIN (handled before this) — they do not collide with themselves.
|
||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||
for id, sw := range active {
|
||||
if id == target {
|
||||
continue
|
||||
}
|
||||
if containsString(evict, id) {
|
||||
return true
|
||||
}
|
||||
if containsString(sw.evict, target) {
|
||||
return true
|
||||
}
|
||||
if slicesOverlap(evict, sw.evict) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// slicesOverlap reports whether xs and ys share any common element.
|
||||
func slicesOverlap(xs, ys []string) bool {
|
||||
for _, x := range xs {
|
||||
if containsString(ys, x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
||||
// requests. Stopping a busy process would cancel its callers' connections, so
|
||||
// the scheduler defers the swap until those callers finish.
|
||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
||||
for _, m := range evict {
|
||||
if inFlight[m] > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(xs []string, s string) bool {
|
||||
for _, x := range xs {
|
||||
if x == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||
// drained first so the consumer always sees the latest position.
|
||||
func broadcastQueuePositions(queued []HandlerReq) {
|
||||
for i, req := range queued {
|
||||
pos := i + 1
|
||||
select {
|
||||
case req.PositionCh <- pos:
|
||||
default:
|
||||
select {
|
||||
case <-req.PositionCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case req.PositionCh <- pos:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,779 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// FIFO methods all run on the router's single run-loop goroutine, so these
|
||||
// tests drive them directly and synchronously. A swap is "completed" by calling
|
||||
// OnSwapDone, a served request "finishes" by calling OnServeDone — exactly the
|
||||
// events the run loop would deliver. fakeEffects records every side-effect and
|
||||
// stubPlanner supplies a fixed eviction set per target.
|
||||
|
||||
// stubPlanner returns a fixed eviction list per target.
|
||||
type stubPlanner struct {
|
||||
evict map[string][]string
|
||||
}
|
||||
|
||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||
if s.evict == nil {
|
||||
return nil
|
||||
}
|
||||
return s.evict[target]
|
||||
}
|
||||
|
||||
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||
|
||||
// grantRec is one GrantError / GrantServe call. err!=nil marks an error grant;
|
||||
// otherwise it is a serve grant and serve reports whether the caller received it.
|
||||
type grantRec struct {
|
||||
model string
|
||||
err error
|
||||
serve bool
|
||||
}
|
||||
|
||||
type startRec struct {
|
||||
model string
|
||||
evict []string
|
||||
}
|
||||
|
||||
type stopRec struct {
|
||||
timeout time.Duration
|
||||
ids []string
|
||||
}
|
||||
|
||||
// fakeEffects is an in-memory scheduler.Effects. Tests program process states
|
||||
// and GrantServe outcomes, then assert on the recorded calls.
|
||||
type fakeEffects struct {
|
||||
states map[string]process.ProcessState // model -> state; missing => not handled
|
||||
serveResult map[string]bool // GrantServe return per model (default true)
|
||||
lastServeReq HandlerReq
|
||||
|
||||
starts []startRec
|
||||
grants []grantRec
|
||||
stops []stopRec
|
||||
}
|
||||
|
||||
func newFakeEffects() *fakeEffects {
|
||||
return &fakeEffects{
|
||||
states: map[string]process.ProcessState{},
|
||||
serveResult: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeEffects) ModelState(modelID string) (process.ProcessState, bool) {
|
||||
st, ok := f.states[modelID]
|
||||
return st, ok
|
||||
}
|
||||
|
||||
func (f *fakeEffects) RunningModels() map[string]process.ProcessState {
|
||||
out := make(map[string]process.ProcessState)
|
||||
for id, st := range f.states {
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
out[id] = st
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (f *fakeEffects) StartSwap(modelID string, evict []string) {
|
||||
f.starts = append(f.starts, startRec{model: modelID, evict: evict})
|
||||
}
|
||||
|
||||
func (f *fakeEffects) GrantError(req HandlerReq, err error) {
|
||||
f.grants = append(f.grants, grantRec{model: req.Model, err: err})
|
||||
}
|
||||
|
||||
func (f *fakeEffects) GrantServe(req HandlerReq, modelID string) bool {
|
||||
ok := true
|
||||
if v, set := f.serveResult[modelID]; set {
|
||||
ok = v
|
||||
}
|
||||
f.lastServeReq = req
|
||||
f.grants = append(f.grants, grantRec{model: modelID, serve: ok})
|
||||
return ok
|
||||
}
|
||||
|
||||
func (f *fakeEffects) StopProcesses(timeout time.Duration, ids []string) {
|
||||
f.stops = append(f.stops, stopRec{timeout: timeout, ids: ids})
|
||||
}
|
||||
|
||||
// served counts grants that handed modelID a handler and were received.
|
||||
func (f *fakeEffects) served(modelID string) int {
|
||||
n := 0
|
||||
for _, g := range f.grants {
|
||||
if g.err == nil && g.serve && g.model == modelID {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// errored counts error grants, optionally filtered by model ("" = any).
|
||||
func (f *fakeEffects) errored(model string) int {
|
||||
n := 0
|
||||
for _, g := range f.grants {
|
||||
if g.err != nil && (model == "" || g.model == model) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// startsFor counts StartSwap calls for modelID.
|
||||
func (f *fakeEffects) startsFor(modelID string) int {
|
||||
n := 0
|
||||
for _, s := range f.starts {
|
||||
if s.model == modelID {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func newFIFO(planner Swapper, eff Effects) *FIFO {
|
||||
return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, nil, eff)
|
||||
}
|
||||
|
||||
func req(model string) HandlerReq { return HandlerReq{Model: model} }
|
||||
|
||||
// reqCh creates a HandlerReq with a unique Respond channel so OnCancel can
|
||||
// identify it among queued requests and swap waiters.
|
||||
func reqCh(model string) HandlerReq {
|
||||
return HandlerReq{
|
||||
Model: model,
|
||||
Respond: make(chan HandlerResp, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_FastPath(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.startsFor("a"); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0 (fast path should not swap)", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_GrantSetsPriorityMetadata(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
cfg := config.FifoConfig{Priority: map[string]int{"a": 7}}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, cfg, nil, eff)
|
||||
|
||||
ctx := shared.SetContext(context.Background(), shared.ReqContextData{ModelID: "a", Metadata: make(map[string]string)})
|
||||
s.OnRequest(HandlerReq{Model: "a", Ctx: ctx})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
data, ok := shared.ReadContext(eff.lastServeReq.Ctx)
|
||||
if !ok {
|
||||
t.Fatal("context data missing from granted request")
|
||||
}
|
||||
if data.Metadata["fifo_priority"] != "7" {
|
||||
t.Errorf("fifo_priority = %q, want 7", data.Metadata["fifo_priority"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_ModelNotFound(t *testing.T) {
|
||||
eff := newFakeEffects() // no states => model unknown
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("ghost"))
|
||||
|
||||
if got := len(eff.starts); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0", got)
|
||||
}
|
||||
if eff.errored("ghost") != 1 {
|
||||
t.Fatalf("want 1 error grant for ghost, grants=%+v", eff.grants)
|
||||
}
|
||||
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnDemandStartThenServe(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 before swap completes", got)
|
||||
}
|
||||
|
||||
// Swap finishes, model is now ready.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 after swap done", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_JoinInFlightSwap(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // starts swap
|
||||
s.OnRequest(req("a")) // joins
|
||||
s.OnRequest(req("a")) // joins
|
||||
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1 (all three share one swap)", got)
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 3 {
|
||||
t.Errorf("served(a)=%d want 3 (one swap serves all waiters)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_SwapDoneError_FailsAllWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||
|
||||
if eff.served("a") != 0 {
|
||||
t.Errorf("served(a)=%d want 0 on swap error", eff.served("a"))
|
||||
}
|
||||
if eff.errored("a") != 2 {
|
||||
t.Errorf("errored(a)=%d want 2 (both waiters fail)", eff.errored("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueOnEvictionCollision covers a request whose target evicts the
|
||||
// model currently being swapped: it must queue until that swap finishes AND its
|
||||
// served request drains, because starting it would stop a busy process.
|
||||
func TestFIFO_QueueOnEvictionCollision(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// Loading b evicts a.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("b")) // collides with a's in-flight swap -> queue
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started early: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a becomes ready and is granted (now serving, inFlight[a]=1).
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started while a is serving: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a's request finishes -> a no longer in-flight -> b may now swap.
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a drained", got)
|
||||
}
|
||||
if got := eff.starts[len(eff.starts)-1].evict; len(got) != 1 || got[0] != "a" {
|
||||
t.Errorf("b swap evict=%v want [a]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_DisjointSwapsRunInParallel verifies two requests with
|
||||
// non-conflicting evict sets both start without waiting for each other.
|
||||
func TestFIFO_DisjointSwapsRunInParallel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff) // empty evicts
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b"))
|
||||
|
||||
if eff.startsFor("a") != 1 || eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap a=%d b=%d want 1 each (parallel)", eff.startsFor("a"), eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OverlappingEvictSetsDoNotRunInParallel verifies two swaps with
|
||||
// different targets that evict the *same* model do not run concurrently: the
|
||||
// second must queue rather than double-evict the shared model. Neither target is
|
||||
// in the other's evict set, so this is only caught by the evict-set overlap
|
||||
// check in collidesWith.
|
||||
func TestFIFO_OverlappingEvictSetsDoNotRunInParallel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["x"] = process.StateReady // shared eviction target, running
|
||||
// Loading a or b both require evicting x.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"x"}, "b": {"x"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a, [x])
|
||||
s.OnRequest(req("b")) // overlaps a's evict set ([x]) -> queue
|
||||
if eff.startsFor("a") != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", eff.startsFor("a"))
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started in parallel while a evicts x: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a's swap completes and x is gone; b can now evict nothing and start.
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["x"] = process.StateStopped
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a's swap drained", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueDrainPromotesMultiple verifies completing one swap unblocks
|
||||
// every queued request that no longer collides — they all start together.
|
||||
func TestFIFO_QueueDrainPromotesMultiple(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["c"] = process.StateStopped
|
||||
// a's swap evicts both b and c; b and c evict nothing.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"b", "c"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a, [b,c])
|
||||
s.OnRequest(req("b")) // collides (in a's evict set) -> queue
|
||||
s.OnRequest(req("c")) // collides -> queue
|
||||
if eff.startsFor("b") != 0 || eff.startsFor("c") != 0 {
|
||||
t.Fatalf("b/c started early")
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
// b and c have empty evict sets and don't evict a, so both start now.
|
||||
if eff.startsFor("b") != 1 || eff.startsFor("c") != 1 {
|
||||
t.Fatalf("StartSwap b=%d c=%d want 1 each after a done", eff.startsFor("b"), eff.startsFor("c"))
|
||||
}
|
||||
if eff.served("a") != 1 {
|
||||
t.Errorf("served(a)=%d want 1", eff.served("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueCollation verifies duplicate requests collapse into one swap
|
||||
// per model: the second request for each model joins the active swap (at arrival
|
||||
// or at drain time) rather than triggering its own swap.
|
||||
func TestFIFO_QueueCollation(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
eff.states[id] = process.StateStopped
|
||||
}
|
||||
// Each model evicts the other two: all swaps are mutually exclusive.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
"b": {"a", "c"},
|
||||
"c": {"a", "b"},
|
||||
}}, eff)
|
||||
|
||||
for _, id := range []string{"a", "b", "c", "a", "b", "c"} {
|
||||
s.OnRequest(req(id))
|
||||
}
|
||||
|
||||
// Drain a, then its served requests, which promotes b; repeat for b -> c.
|
||||
drain := func(model string, waiters int) {
|
||||
eff.states[model] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: model})
|
||||
for i := 0; i < waiters; i++ {
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||
}
|
||||
}
|
||||
drain("a", 2)
|
||||
drain("b", 2)
|
||||
drain("c", 2)
|
||||
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
if got := eff.startsFor(id); got != 1 {
|
||||
t.Errorf("StartSwap(%s)=%d want 1 (collation)", id, got)
|
||||
}
|
||||
if got := eff.served(id); got != 2 {
|
||||
t.Errorf("served(%s)=%d want 2", id, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_NoSwapWhileServing verifies a model still handling requests is not
|
||||
// evicted: the evicting request waits until every in-flight request drains.
|
||||
func TestFIFO_NoSwapWhileServing(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // fast path, inFlight[a]=1
|
||||
s.OnRequest(req("a")) // fast path, inFlight[a]=2
|
||||
s.OnRequest(req("b")) // would evict busy a -> queue
|
||||
if eff.startsFor("b") != 0 {
|
||||
t.Fatalf("b started while a serving")
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=1
|
||||
if eff.startsFor("b") != 0 {
|
||||
t.Fatalf("b started while a still serving one request")
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=0
|
||||
if eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a fully drained", eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_GrantServeFalseDoesNotLeakInFlight verifies that when a caller has
|
||||
// walked away (GrantServe returns false) the in-flight count is not bumped, so a
|
||||
// later evicting request is not blocked forever.
|
||||
func TestFIFO_GrantServeFalseDoesNotLeakInFlight(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.serveResult["a"] = false // a's waiter is gone by grant time
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails, inFlight[a] stays 0
|
||||
|
||||
// b evicts a; since a is not in-flight, b should start immediately.
|
||||
s.OnRequest(req("b"))
|
||||
if eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 (no leaked in-flight on a)", eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnShutdown_FailsAllWaiters verifies shutdown errors every waiter the
|
||||
// scheduler holds: active-swap waiters and queued requests alike.
|
||||
func TestFIFO_OnShutdown_FailsAllWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
eff.states[id] = process.StateStopped
|
||||
}
|
||||
// a and b load in parallel; c collides with both and queues.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"c": {"a", "b"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("a")) // join a
|
||||
s.OnRequest(req("b")) // StartSwap(b)
|
||||
s.OnRequest(req("b")) // join b
|
||||
s.OnRequest(req("c")) // queued
|
||||
|
||||
s.OnShutdown(errors.New("shutting down"))
|
||||
|
||||
if got := eff.errored(""); got != 5 {
|
||||
t.Errorf("error grants=%d want 5 (2 a + 2 b + 1 c)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnUnload_ReleasesActiveWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // active swap a with one waiter
|
||||
s.OnRequest(req("a")) // join
|
||||
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
|
||||
if got := eff.errored("a"); got != 2 {
|
||||
t.Errorf("errored(a)=%d want 2 (active swap waiters released)", got)
|
||||
}
|
||||
if len(eff.stops) != 1 || len(eff.stops[0].ids) != 1 || eff.stops[0].ids[0] != "a" {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
if eff.stops[0].timeout != time.Second {
|
||||
t.Errorf("StopProcesses timeout=%v want 1s", eff.stops[0].timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnUnload_DropsQueuedRequests(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// b evicts a, so a request for b queues while a is loading.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
s.OnUnload([]string{"b"}, time.Second)
|
||||
|
||||
if got := eff.errored("b"); got != 1 {
|
||||
t.Errorf("errored(b)=%d want 1 (queued request dropped)", got)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (b should never start)", got)
|
||||
}
|
||||
// a's swap is untouched: its waiter is neither served nor errored yet.
|
||||
if eff.served("a") != 0 || eff.errored("a") != 0 {
|
||||
t.Errorf("a swap should be untouched: served=%d errored=%d", eff.served("a"), eff.errored("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_PriorityQueueOrder verifies queued requests are ordered by descending
|
||||
// priority, with arrival (FIFO) order preserved among equal-priority models.
|
||||
func TestFIFO_PriorityQueueOrder(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, m := range []string{"z", "A", "B", "C", "D"} {
|
||||
eff.states[m] = process.StateStopped
|
||||
}
|
||||
// z's swap evicts every other model, so any request that arrives while z is
|
||||
// loading collides with z's in-flight swap and parks in the queue.
|
||||
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}}
|
||||
cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, nil, eff)
|
||||
|
||||
s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D])
|
||||
|
||||
// Arrive out of priority order; B before C exercises FIFO tie-breaking.
|
||||
for _, m := range []string{"B", "D", "C", "A"} {
|
||||
s.OnRequest(req(m))
|
||||
}
|
||||
|
||||
got := make([]string, len(s.queued))
|
||||
for i, q := range s.queued {
|
||||
got[i] = q.Model
|
||||
}
|
||||
want := []string{"A", "B", "C", "D"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("queue=%v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("queue=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_QueuedRequest verifies that cancelling a queued request
|
||||
// prevents drainQueue from ever starting a model load for it. Without OnCancel
|
||||
// the dead request would sit in the queue until a drain triggers a wasted swap.
|
||||
func TestFIFO_OnCancel_QueuedRequest(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// b evicts a, so a request for b queues while a is loading.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
|
||||
cancelledReq := reqCh("b")
|
||||
s.OnRequest(cancelledReq) // queued (collides with a's in-flight swap)
|
||||
if len(s.queued) != 1 {
|
||||
t.Fatalf("queue len=%d want 1 before cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// Client disconnects.
|
||||
s.OnCancel(cancelledReq)
|
||||
|
||||
if len(s.queued) != 0 {
|
||||
t.Fatalf("queue len=%d want 0 after cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// a's swap finishes; drainQueue runs but b is gone — no swap for b.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (cancelled request should not trigger a load)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_SwapWaiter verifies that cancelling a request that joined an
|
||||
// in-flight swap removes it from the waiter list. When the swap completes, the
|
||||
// cancelled waiter receives no grant and does not bump the in-flight count.
|
||||
func TestFIFO_OnCancel_SwapWaiter(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
liveReq := reqCh("a")
|
||||
cancelledReq := reqCh("a")
|
||||
s.OnRequest(liveReq) // starts swap
|
||||
s.OnRequest(cancelledReq) // joins
|
||||
|
||||
if sw := s.active["a"]; len(sw.waiters) != 2 {
|
||||
t.Fatalf("waiters=%d want 2", len(sw.waiters))
|
||||
}
|
||||
|
||||
s.OnCancel(cancelledReq)
|
||||
|
||||
if sw := s.active["a"]; len(sw.waiters) != 1 {
|
||||
t.Fatalf("waiters=%d want 1 after cancel", len(sw.waiters))
|
||||
}
|
||||
|
||||
// Swap finishes: only the live waiter is granted.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 (only the non-cancelled waiter)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_NotPresent is a no-op: cancelling a request that was already
|
||||
// granted (and is no longer queued or waiting) must not affect anything.
|
||||
func TestFIFO_OnCancel_NotPresent(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
r := reqCh("a")
|
||||
s.OnRequest(r) // fast-path served immediately
|
||||
|
||||
// Cancel after grant — should be a harmless no-op.
|
||||
s.OnCancel(r)
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 (cancel of granted request is a no-op)", got)
|
||||
}
|
||||
if len(s.queued) != 0 {
|
||||
t.Errorf("queue should be empty, len=%d", len(s.queued))
|
||||
}
|
||||
}
|
||||
|
||||
// newFIFOWithLimit builds a FIFO whose single model has the given concurrency
|
||||
// limit, already in StateReady so every request exercises the fast path.
|
||||
func newFIFOWithLimit(t *testing.T, model string, limit int) (*FIFO, *fakeEffects) {
|
||||
t.Helper()
|
||||
eff := newFakeEffects()
|
||||
eff.states[model] = process.StateReady
|
||||
models := map[string]config.ModelConfig{
|
||||
model: {ConcurrencyLimit: limit},
|
||||
}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
|
||||
return s, eff
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_RejectsOverLimit verifies that a request arriving
|
||||
// while the model is at capacity gets an error grant instead of being served,
|
||||
// and that a new request succeeds once an in-flight one completes.
|
||||
func TestFIFO_ConcurrencyLimit_RejectsOverLimit(t *testing.T) {
|
||||
s, eff := newFIFOWithLimit(t, "a", 1)
|
||||
|
||||
// First request: served (inFlight 0 → 1).
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Second request while slot is occupied: rejected with HTTPError 429.
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (over-limit)", got)
|
||||
}
|
||||
var httpErr shared.HTTPError
|
||||
if !errors.As(eff.grants[len(eff.grants)-1].err, &httpErr) {
|
||||
t.Fatalf("err=%v want HTTPError", eff.grants[len(eff.grants)-1].err)
|
||||
}
|
||||
if httpErr.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("StatusCode()=%d want 429", httpErr.StatusCode())
|
||||
}
|
||||
if httpErr.Header().Get("Retry-After") == "" {
|
||||
t.Fatal("missing Retry-After header")
|
||||
}
|
||||
|
||||
// After the in-flight request finishes, a new request succeeds.
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2 after drain", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_DefaultIsTen verifies that a model without an
|
||||
// explicit ConcurrencyLimit gets the default cap of 10.
|
||||
func TestFIFO_ConcurrencyLimit_DefaultIsTen(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
// nil models → every model gets defaultConcurrencyLimit (10).
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
s.OnRequest(req("a"))
|
||||
}
|
||||
if got := eff.served("a"); got != 10 {
|
||||
t.Fatalf("served(a)=%d want 10 (default limit)", got)
|
||||
}
|
||||
|
||||
// 11th request is rejected.
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (over default limit)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_CustomLimit verifies a ConcurrencyLimit greater
|
||||
// than zero overrides the default.
|
||||
func TestFIFO_ConcurrencyLimit_CustomLimit(t *testing.T) {
|
||||
s, eff := newFIFOWithLimit(t, "a", 2)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2 (custom limit)", got)
|
||||
}
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (over custom limit)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_SwapWaiters verifies that when more swap waiters
|
||||
// exist than the concurrency limit, excess waiters are rejected on swap
|
||||
// completion rather than exceeding the limit.
|
||||
func TestFIFO_ConcurrencyLimit_SwapWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
models := map[string]config.ModelConfig{
|
||||
"a": {ConcurrencyLimit: 2},
|
||||
}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
|
||||
|
||||
// Three requests arrive while model is loading: one starts swap, two join.
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Swap completes: two served (limit), one rejected.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2 (limit on swap completion)", got)
|
||||
}
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (excess waiter rejected)", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
// Package scheduler contains the request-scheduling strategies used by the
|
||||
// router's baseRouter. A Scheduler owns the queue, in-flight tracking, and the
|
||||
// decision tree for when to start a swap versus queue a request. The baseRouter
|
||||
// owns the channels, run loop, and process machinery, and exposes the
|
||||
// side-effects a scheduler needs through the Effects interface.
|
||||
//
|
||||
// Splitting these apart lets the scheduling strategy be swapped out
|
||||
// independently of both the process machinery (baseRouter) and the eviction
|
||||
// policy (Swapper). FIFO is the first and currently only implementation.
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// ErrModelNotFound is granted to callers whose model is not handled by this
|
||||
// router. It is an alias for shared.ErrNoLocalModelFound.
|
||||
var ErrModelNotFound = shared.ErrNoLocalModelFound
|
||||
|
||||
// Swapper is the eviction policy: it decides which running models must be
|
||||
// stopped before a target can serve. It is orthogonal to the scheduling
|
||||
// strategy — any Scheduler works with any Swapper.
|
||||
type Swapper interface {
|
||||
// EvictionFor returns running model IDs that must be stopped before
|
||||
// target can serve. running is the complete set the scheduler considers
|
||||
// live: every process that is not stopped, unioned with the targets of
|
||||
// in-flight swaps the scheduler has already committed to (which are not yet
|
||||
// visible in process state). The planner does not inspect process state
|
||||
// itself. Pure decision; must not log.
|
||||
EvictionFor(target string, running []string) []string
|
||||
|
||||
// OnSwapStart runs once at the start of every swap, with the same running
|
||||
// set EvictionFor was given for this decision. Planners may log their
|
||||
// decision here at whatever verbosity they choose.
|
||||
OnSwapStart(target string, running []string)
|
||||
}
|
||||
|
||||
// Scheduler decides what happens to each event the router's run loop receives.
|
||||
// All methods run on that single run-loop goroutine, so implementations need no
|
||||
// internal locking for their own state.
|
||||
type Scheduler interface {
|
||||
// OnRequest handles one incoming ServeHTTP request.
|
||||
OnRequest(req HandlerReq)
|
||||
// OnCancel handles a request whose client has disconnected before it was
|
||||
// granted. The scheduler must remove the request from its queue and from
|
||||
// any in-flight swap's waiters so it never triggers a model load or grant
|
||||
// for a caller that is no longer there.
|
||||
OnCancel(req HandlerReq)
|
||||
// OnSwapDone handles a swap goroutine reporting completion.
|
||||
OnSwapDone(ev SwapDone)
|
||||
// OnServeDone handles a tracked ServeHTTP finishing (in-flight decrement).
|
||||
OnServeDone(ev ServeDoneEvent)
|
||||
// OnUnload reconciles scheduler state for an unload, stops the targeted
|
||||
// processes via Effects, and drains the queue. It must block until the
|
||||
// targeted processes have stopped.
|
||||
OnUnload(targets []string, timeout time.Duration)
|
||||
// OnShutdown grants err to every waiter the scheduler still holds (active
|
||||
// swap waiters and queued requests). Process teardown is the baseRouter's
|
||||
// responsibility.
|
||||
OnShutdown(err error)
|
||||
}
|
||||
|
||||
// Effects is implemented by the baseRouter. The scheduler calls back through it
|
||||
// for every side-effect: inspecting process state, launching swaps, responding
|
||||
// to callers, and stopping processes.
|
||||
type Effects interface {
|
||||
// ModelState returns the current state of a model's process. ok is false
|
||||
// when the model is not handled by this router.
|
||||
ModelState(modelID string) (process.ProcessState, bool)
|
||||
// RunningModels returns the state of every process that is not stopped or
|
||||
// shut down, keyed by model ID. The scheduler uses it to build the running
|
||||
// set it hands the Swapper.
|
||||
RunningModels() map[string]process.ProcessState
|
||||
// StartSwap launches the swap goroutine for modelID, stopping evict first.
|
||||
StartSwap(modelID string, evict []string)
|
||||
// GrantError responds to a caller with an error.
|
||||
GrantError(req HandlerReq, err error)
|
||||
// GrantServe hands a caller the wrapped handler for modelID and reports
|
||||
// whether the caller was still there to receive it. The scheduler bumps
|
||||
// its in-flight count only when this returns true.
|
||||
GrantServe(req HandlerReq, modelID string) bool
|
||||
// StopProcesses stops the named processes in parallel and blocks until all
|
||||
// have stopped. Unknown IDs are skipped.
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
|
||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from
|
||||
// conf and bound to the given planner and effects. Supported values are "fifo"
|
||||
// (throughput-oriented, batches same-model requests) and "serial" (strict
|
||||
// one-model-at-a-time, exact arrival order).
|
||||
//
|
||||
// The deployment default is applied by config loading (LoadConfig sets Use to
|
||||
// "serial" when unset). The "" fallback here is the library default and remains
|
||||
// "fifo" so callers that build a Config directly keep the original behavior.
|
||||
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||
use := conf.Routing.Scheduler.Use
|
||||
if use == "" {
|
||||
use = "fifo"
|
||||
}
|
||||
switch use {
|
||||
case "fifo":
|
||||
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||
case "serial":
|
||||
// Serial ignores the group planner: it always evicts every other model.
|
||||
return NewSerial(name, logger, eff), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlerReq is one in-flight ServeHTTP request waiting for a routing decision.
|
||||
type HandlerReq struct {
|
||||
Model string
|
||||
Ctx context.Context
|
||||
Respond chan HandlerResp
|
||||
PositionCh chan int
|
||||
}
|
||||
|
||||
// HandlerResp is the routing decision returned to a HandlerReq's caller: either
|
||||
// a handler to serve with, or an error.
|
||||
type HandlerResp struct {
|
||||
HandleFunc http.HandlerFunc
|
||||
Err error
|
||||
}
|
||||
|
||||
// SwapDone is reported by a swap goroutine when its target is ready (or failed).
|
||||
type SwapDone struct {
|
||||
ModelID string
|
||||
Err error
|
||||
}
|
||||
|
||||
// ServeDoneEvent is reported when a tracked ServeHTTP handler returns.
|
||||
type ServeDoneEvent struct {
|
||||
ModelID string
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders
|
||||
// or batches: requests run in exact arrival order and at most one request runs at
|
||||
// any instant. When the next request targets a model other than the one loaded,
|
||||
// every other running model is evicted and the target is loaded before it runs,
|
||||
// so a single model occupies memory at a time — at the cost of throughput.
|
||||
//
|
||||
// Example: A B C A is served as A B C A. The final A reloads its model even
|
||||
// though it ran first, because B and C displaced it in between. (FIFO, by
|
||||
// contrast, would batch the two A requests: A A B C.)
|
||||
//
|
||||
// Serial ignores group/eviction policy entirely: it always evicts every other
|
||||
// running model, regardless of how groups are configured. That is what makes the
|
||||
// single-model guarantee a property of the scheduler rather than of the config.
|
||||
//
|
||||
// Like FIFO, every method runs on the router's single run-loop goroutine, so no
|
||||
// internal locking is needed.
|
||||
type Serial struct {
|
||||
name string
|
||||
logger *logmon.Monitor
|
||||
effects Effects
|
||||
|
||||
// queued holds requests in strict arrival order. It is never reordered.
|
||||
queued []HandlerReq
|
||||
|
||||
// active is the one request currently being processed (loading or serving),
|
||||
// or nil when idle. phase is meaningful only while active != nil.
|
||||
active *HandlerReq
|
||||
phase serialPhase
|
||||
}
|
||||
|
||||
// serialPhase is the lifecycle stage of the active request.
|
||||
type serialPhase int
|
||||
|
||||
const (
|
||||
phaseIdle serialPhase = iota
|
||||
phaseSwapping // waiting for OnSwapDone for active.Model
|
||||
phaseServing // waiting for OnServeDone for active.Model
|
||||
)
|
||||
|
||||
// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always
|
||||
// "stop every other running model", so the group planner is not consulted.
|
||||
func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial {
|
||||
return &Serial{
|
||||
name: name,
|
||||
logger: logger,
|
||||
effects: eff,
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest validates the model and appends the request to the tail of the queue,
|
||||
// then tries to start the next job. Unknown models fail immediately.
|
||||
func (s *Serial) OnRequest(req HandlerReq) {
|
||||
if _, ok := s.effects.ModelState(req.Model); !ok {
|
||||
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
return
|
||||
}
|
||||
s.queued = append(s.queued, req)
|
||||
broadcastQueuePositions(s.queued)
|
||||
s.startNext()
|
||||
}
|
||||
|
||||
// startNext begins processing the head of the queue when nothing is active. It
|
||||
// fast-paths a request whose model is already the sole loaded-and-ready process;
|
||||
// otherwise it launches a swap that evicts every other running model first. The
|
||||
// loop skips over requests for models that vanished (e.g. a config reload) and
|
||||
// requests whose caller disconnected before they could be served.
|
||||
func (s *Serial) startNext() {
|
||||
if s.active != nil {
|
||||
return // a job is already loading or serving
|
||||
}
|
||||
for len(s.queued) > 0 {
|
||||
req := s.queued[0]
|
||||
s.queued = s.queued[1:]
|
||||
broadcastQueuePositions(s.queued)
|
||||
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
continue
|
||||
}
|
||||
|
||||
r := req
|
||||
s.active = &r
|
||||
|
||||
evict := s.otherRunning(req.Model)
|
||||
if state == process.StateReady && len(evict) == 0 {
|
||||
// Already loaded and the only model running — serve immediately.
|
||||
s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model)
|
||||
if s.serve() {
|
||||
return
|
||||
}
|
||||
continue // caller gone; pick the next request
|
||||
}
|
||||
|
||||
s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict)
|
||||
s.phase = phaseSwapping
|
||||
s.effects.StartSwap(req.Model, evict)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// serve hands the active request its tracked handler. It returns true when the
|
||||
// request is now serving (await OnServeDone); false when the caller had already
|
||||
// disconnected, in which case active is cleared so the next job can start.
|
||||
func (s *Serial) serve() bool {
|
||||
if s.effects.GrantServe(*s.active, s.active.Model) {
|
||||
s.phase = phaseServing
|
||||
return true
|
||||
}
|
||||
s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
return false
|
||||
}
|
||||
|
||||
// OnSwapDone fires when the load for the active request completes. On success the
|
||||
// request is served; on failure its caller receives the error and the queue
|
||||
// advances. A SwapDone that does not match the active load (e.g. its request was
|
||||
// unloaded or cancelled mid-load) is ignored.
|
||||
func (s *Serial) OnSwapDone(ev SwapDone) {
|
||||
if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID {
|
||||
return
|
||||
}
|
||||
if ev.Err != nil {
|
||||
s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err)
|
||||
s.effects.GrantError(*s.active, ev.Err)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
s.startNext()
|
||||
return
|
||||
}
|
||||
if !s.serve() {
|
||||
s.startNext() // caller vanished while the model loaded; move on
|
||||
}
|
||||
}
|
||||
|
||||
// OnServeDone fires when the active request's handler returns. The slot is freed
|
||||
// and the next queued request begins.
|
||||
func (s *Serial) OnServeDone(ev ServeDoneEvent) {
|
||||
if s.active == nil || s.phase != phaseServing {
|
||||
return
|
||||
}
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
s.startNext()
|
||||
}
|
||||
|
||||
// OnCancel removes a disconnected client's request from the queue. A request that
|
||||
// is already active is left to finish: if it was loading, OnSwapDone's serve()
|
||||
// will find the caller gone (GrantServe false) and advance; if it was serving,
|
||||
// its handler returns normally and reaches OnServeDone.
|
||||
func (s *Serial) OnCancel(req HandlerReq) {
|
||||
if len(s.queued) == 0 {
|
||||
return
|
||||
}
|
||||
kept := s.queued[:0]
|
||||
removed := false
|
||||
for _, q := range s.queued {
|
||||
if q.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
if removed {
|
||||
s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model)
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnload reconciles state for an unload, stops the targeted processes, and
|
||||
// advances the queue. It mirrors the FIFO contract: queued requests for unloaded
|
||||
// models are failed; an active *loading* request for an unloaded model is failed
|
||||
// (its swap goroutine is left to finish and its SwapDone is then ignored); an
|
||||
// active *serving* request is left for its handler to end when StopProcesses
|
||||
// kills the upstream. The Stop is synchronous so callers of Unload can rely on
|
||||
// the processes being stopped on return.
|
||||
func (s *Serial) OnUnload(targets []string, timeout time.Duration) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(targets))
|
||||
for _, id := range targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] {
|
||||
s.effects.GrantError(*s.active, unloadErr)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
}
|
||||
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, q := range s.queued {
|
||||
if targetSet[q.Model] {
|
||||
s.effects.GrantError(q, unloadErr)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
s.effects.StopProcesses(timeout, targets)
|
||||
|
||||
// A still-serving active request advances via OnServeDone when its killed
|
||||
// handler returns; only start the next job when nothing is active now.
|
||||
if s.active == nil {
|
||||
s.startNext()
|
||||
}
|
||||
}
|
||||
|
||||
// OnShutdown grants err to every request the scheduler still holds: an active
|
||||
// loading request and all queued requests. A serving request is torn down with
|
||||
// its process by the baseRouter.
|
||||
func (s *Serial) OnShutdown(err error) {
|
||||
if s.active != nil && s.phase == phaseSwapping {
|
||||
s.effects.GrantError(*s.active, err)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
}
|
||||
for _, q := range s.queued {
|
||||
s.effects.GrantError(q, err)
|
||||
}
|
||||
s.queued = nil
|
||||
}
|
||||
|
||||
// otherRunning returns every running model except target, sorted for
|
||||
// deterministic eviction.
|
||||
func (s *Serial) otherRunning(target string) []string {
|
||||
var out []string
|
||||
for id := range s.effects.RunningModels() {
|
||||
if id != target {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// Serial methods all run on the router's single run-loop goroutine, so these
|
||||
// tests drive them directly and synchronously, reusing fakeEffects and the
|
||||
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
|
||||
// served request finishes via OnServeDone — the events the run loop delivers.
|
||||
|
||||
func newSerial(eff Effects) *Serial {
|
||||
return NewSerial("test", logmon.NewWriter(io.Discard), eff)
|
||||
}
|
||||
|
||||
// lastStart returns the most recent StartSwap record.
|
||||
func lastStart(t *testing.T, eff *fakeEffects) startRec {
|
||||
t.Helper()
|
||||
if len(eff.starts) == 0 {
|
||||
t.Fatal("no StartSwap recorded")
|
||||
}
|
||||
return eff.starts[len(eff.starts)-1]
|
||||
}
|
||||
|
||||
func sameSet(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
m := map[string]int{}
|
||||
for _, x := range a {
|
||||
m[x]++
|
||||
}
|
||||
for _, x := range b {
|
||||
m[x]--
|
||||
}
|
||||
for _, v := range m {
|
||||
if v != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// servedOrder returns the model IDs of every successful serve grant in order.
|
||||
func servedOrder(eff *fakeEffects) []string {
|
||||
var out []string
|
||||
for _, g := range eff.grants {
|
||||
if g.err == nil && g.serve {
|
||||
out = append(out, g.model)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestSerial_FastPath_AlreadyLoaded(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := len(eff.starts); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_ColdStart_LoadsThenServes(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 before load completes", got)
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 after load", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_UnknownModel(t *testing.T) {
|
||||
eff := newFakeEffects() // no states => unknown
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("ghost"))
|
||||
|
||||
if len(eff.starts) != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0", len(eff.starts))
|
||||
}
|
||||
if eff.errored("ghost") != 1 {
|
||||
t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost"))
|
||||
}
|
||||
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_EvictsEveryOtherModel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["x"] = process.StateReady // already running
|
||||
eff.states["y"] = process.StateReady // also running (e.g. left over)
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
st := lastStart(t, eff)
|
||||
if st.model != "a" {
|
||||
t.Fatalf("loading %s want a", st.model)
|
||||
}
|
||||
if !sameSet(st.evict, []string{"x", "y"}) {
|
||||
t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_OneJobAtATime verifies a second request waits while the first is
|
||||
// serving, and only starts after the first finishes.
|
||||
func TestSerial_OneJobAtATime(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // served immediately
|
||||
s.OnRequest(req("b")) // must wait — a is serving
|
||||
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// a finishes -> b may now load (evicting a).
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a finished", got)
|
||||
}
|
||||
if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) {
|
||||
t.Errorf("b evict=%v want [a]", st.evict)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the
|
||||
// already-loaded model run without a reload, one after another.
|
||||
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // cold load
|
||||
s.OnRequest(req("a")) // queued behind the first
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1 (one at a time)", got)
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2", got)
|
||||
}
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl,
|
||||
// qwen36 execute in EXACTLY that order with evictions between each model switch,
|
||||
// including reloading qwen36 at the end even though it ran first.
|
||||
func TestSerial_StrictArrivalOrder(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, m := range []string{"qwen36", "qwen35", "sdxl"} {
|
||||
eff.states[m] = process.StateStopped
|
||||
}
|
||||
s := newSerial(eff)
|
||||
|
||||
for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} {
|
||||
s.OnRequest(req(m))
|
||||
}
|
||||
|
||||
// Only the first job starts loading; the rest wait their turn.
|
||||
if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" {
|
||||
t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts)
|
||||
}
|
||||
|
||||
// step completes the current model's load+serve and returns control to the
|
||||
// scheduler, which must start the next queued model.
|
||||
step := func(model string, wantEvict []string) {
|
||||
t.Helper()
|
||||
st := lastStart(t, eff)
|
||||
if st.model != model {
|
||||
t.Fatalf("loading %q want %q", st.model, model)
|
||||
}
|
||||
if !sameSet(st.evict, wantEvict) {
|
||||
t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict)
|
||||
}
|
||||
// Simulate the eviction + load actually happening.
|
||||
for _, e := range st.evict {
|
||||
eff.states[e] = process.StateStopped
|
||||
}
|
||||
eff.states[model] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: model})
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||
}
|
||||
|
||||
step("qwen36", nil) // cold load, nothing else running
|
||||
step("qwen35", []string{"qwen36"}) // evict qwen36
|
||||
step("sdxl", []string{"qwen35"}) // evict qwen35
|
||||
step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl
|
||||
|
||||
want := []string{"qwen36", "qwen35", "sdxl", "qwen36"}
|
||||
if got := servedOrder(eff); !sameOrder(got, want) {
|
||||
t.Fatalf("serve order=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sameOrder(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b")) // queued behind a
|
||||
|
||||
// a's load fails: its caller is errored and b proceeds.
|
||||
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||
if eff.errored("a") != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1", eff.errored("a"))
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_GrantServeFalse_Advances verifies that when the active request's
|
||||
// caller has disconnected by serve time, the queue advances to the next request.
|
||||
func TestSerial_GrantServeFalse_Advances(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.serveResult["a"] = false // a's caller is gone by grant time
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b
|
||||
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 (caller gone)", got)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnCancel_QueuedRequest(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(reqCh("a")) // starts loading a
|
||||
cancelled := reqCh("b")
|
||||
s.OnRequest(cancelled) // queued behind a
|
||||
if len(s.queued) != 1 {
|
||||
t.Fatalf("queued=%d want 1", len(s.queued))
|
||||
}
|
||||
|
||||
s.OnCancel(cancelled)
|
||||
if len(s.queued) != 0 {
|
||||
t.Fatalf("queued=%d want 0 after cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// a completes; b is gone, so nothing starts for it.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["c"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // active (loading)
|
||||
s.OnRequest(req("b")) // queued
|
||||
s.OnRequest(req("c")) // queued
|
||||
|
||||
s.OnShutdown(errors.New("shutting down"))
|
||||
|
||||
if got := eff.errored(""); got != 3 {
|
||||
t.Errorf("error grants=%d want 3 (active load + 2 queued)", got)
|
||||
}
|
||||
if len(s.queued) != 0 {
|
||||
t.Errorf("queued=%d want 0 after shutdown", len(s.queued))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_OnUnload_WhileServing verifies that unloading the model that is
|
||||
// actively serving does not strand the queue: OnUnload stops the process but
|
||||
// leaves the active request to end via OnServeDone, which then advances.
|
||||
func TestSerial_OnUnload_WhileServing(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // served immediately (a ready)
|
||||
s.OnRequest(req("b")) // queued behind a
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Unload a while it is serving: the process is stopped, but the queue must
|
||||
// not advance yet — the active serve is still outstanding.
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got)
|
||||
}
|
||||
|
||||
// The killed handler returns -> OnServeDone advances to b.
|
||||
eff.states["a"] = process.StateStopped
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // active (loading a)
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
// Unload a: its active load is failed and a is stopped.
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
|
||||
if eff.errored("a") != 1 {
|
||||
t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a"))
|
||||
}
|
||||
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
// b was queued and not unloaded; with a's load cancelled it now starts.
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiUnloadTimeout is used by the API endpoints to stop processes
|
||||
const apiUnloadTimeout = 10 * time.Second
|
||||
|
||||
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
||||
type modelRecord struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Architecture map[string]any `json:"architecture,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||
ContextLength int `json:"context_length,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// cappedMetadataKeys are top-level /v1/models fields produced by the
|
||||
// capabilities renderer. If a model's metadata block defines any of these
|
||||
// keys, the renderer's values win and the metadata keys are dropped.
|
||||
var cappedMetadataKeys = map[string]struct{}{
|
||||
"architecture": {},
|
||||
"capabilities": {},
|
||||
"supported_parameters": {},
|
||||
"context_length": {},
|
||||
}
|
||||
|
||||
// renderCapabilities converts a model's capabilities config into additional
|
||||
// /v1/models fields. Returns zero values when caps.Empty() is true.
|
||||
func renderCapabilities(caps config.ModelCapConfig) (arch map[string]any, capsMap map[string]any, params []string, ctxLen int) {
|
||||
if caps.Empty() {
|
||||
return
|
||||
}
|
||||
|
||||
hasIn := len(caps.In) > 0
|
||||
hasOut := len(caps.Out) > 0
|
||||
|
||||
if hasIn || hasOut {
|
||||
arch = make(map[string]any)
|
||||
}
|
||||
if hasIn {
|
||||
arch["input_modalities"] = caps.In
|
||||
}
|
||||
if hasOut {
|
||||
arch["output_modalities"] = caps.Out
|
||||
}
|
||||
if hasIn && hasOut {
|
||||
arch["modality"] = strings.Join(caps.In, "+") + "->" + strings.Join(caps.Out, "+")
|
||||
}
|
||||
|
||||
// Build capabilities map only if there's something to put in it.
|
||||
if hasIn || hasOut || caps.Tools || caps.Reranker {
|
||||
capsMap = make(map[string]any)
|
||||
}
|
||||
|
||||
if hasIn {
|
||||
if contains(caps.In, "image") {
|
||||
capsMap["vision"] = true
|
||||
}
|
||||
}
|
||||
if hasIn && hasOut {
|
||||
if contains(caps.In, "audio") && contains(caps.Out, "text") {
|
||||
capsMap["audio_transcriptions"] = true
|
||||
}
|
||||
if contains(caps.In, "text") && contains(caps.Out, "audio") {
|
||||
capsMap["audio_speech"] = true
|
||||
}
|
||||
if contains(caps.In, "text") && contains(caps.Out, "image") {
|
||||
capsMap["image_generation"] = true
|
||||
}
|
||||
if contains(caps.In, "image") && contains(caps.Out, "image") {
|
||||
capsMap["image_to_image"] = true
|
||||
}
|
||||
}
|
||||
|
||||
if caps.Tools {
|
||||
capsMap["function_calling"] = true
|
||||
params = []string{"tools", "tool_choice"}
|
||||
}
|
||||
|
||||
if caps.Reranker {
|
||||
capsMap["reranker"] = true
|
||||
}
|
||||
|
||||
if caps.Context > 0 {
|
||||
ctxLen = caps.Context
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// contains reports whether s is present in ss.
|
||||
func contains(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterCappedMetadata returns metadata with renderer-owned keys removed.
|
||||
func filterCappedMetadata(md map[string]any) map[string]any {
|
||||
if len(md) == 0 {
|
||||
return nil
|
||||
}
|
||||
filtered := make(map[string]any, len(md))
|
||||
for k, v := range md {
|
||||
if _, capped := cappedMetadataKeys[k]; !capped {
|
||||
filtered[k] = v
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||
// (with optional aliases) plus peer models.
|
||||
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
created := time.Now().Unix()
|
||||
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||
|
||||
newRecord := func(id, name, description string, metadata map[string]any, caps config.ModelCapConfig) modelRecord {
|
||||
rec := modelRecord{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
Created: created,
|
||||
OwnedBy: "llama-swap",
|
||||
Name: strings.TrimSpace(name),
|
||||
Description: strings.TrimSpace(description),
|
||||
}
|
||||
rec.Architecture, rec.Capabilities, rec.SupportedParameters, rec.ContextLength = renderCapabilities(caps)
|
||||
if !caps.Empty() {
|
||||
metadata = filterCappedMetadata(metadata)
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||
}
|
||||
return rec
|
||||
}
|
||||
|
||||
for id, mc := range s.cfg.Models {
|
||||
if mc.Unlisted {
|
||||
continue
|
||||
}
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||
|
||||
if s.cfg.IncludeAliasesInList {
|
||||
for _, alias := range mc.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}, config.ModelCapConfig{}))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID })
|
||||
|
||||
// Echo the Origin so browser clients can read the listing.
|
||||
if origin := r.Header.Get("Origin"); origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// runningModel is one entry in the /running listing.
|
||||
type runningModel struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// handleUnload stops every running local process. Peer models are remote and
|
||||
// unaffected.
|
||||
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(apiUnloadTimeout)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
// handleRunning lists local processes that are not stopped, joining each model
|
||||
// ID against its config for the cmd/proxy/ttl/name/description metadata.
|
||||
func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) {
|
||||
states := s.local.RunningModels()
|
||||
list := make([]runningModel, 0, len(states))
|
||||
for id, state := range states {
|
||||
mc := s.cfg.Models[id]
|
||||
list = append(list, runningModel{
|
||||
Model: id,
|
||||
State: string(state),
|
||||
Cmd: mc.Cmd,
|
||||
Proxy: mc.Proxy,
|
||||
TTL: mc.UnloadAfter,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
})
|
||||
}
|
||||
sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model })
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{"running": list})
|
||||
}
|
||||
|
||||
// discardResponseWriter satisfies http.ResponseWriter for preload requests,
|
||||
// dropping the body while capturing the status code.
|
||||
type discardResponseWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (d *discardResponseWriter) Header() http.Header {
|
||||
if d.header == nil {
|
||||
d.header = make(http.Header)
|
||||
}
|
||||
return d.header
|
||||
}
|
||||
|
||||
func (d *discardResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
func (d *discardResponseWriter) WriteHeader(status int) { d.status = status }
|
||||
|
||||
// startPreload fires a background GET / at every model named in
|
||||
// Hooks.OnStartup.Preload so they are warm before the first real request.
|
||||
// Preload names are already resolved to real model IDs by config loading.
|
||||
func (s *Server) startPreload() {
|
||||
models := s.cfg.Hooks.OnStartup.Preload
|
||||
if len(models) == 0 {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for _, modelID := range models {
|
||||
if !s.local.Handles(modelID) {
|
||||
s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID)
|
||||
continue
|
||||
}
|
||||
s.proxylog.Infof("preloading model: %s", modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||
|
||||
dw := &discardResponseWriter{status: http.StatusOK}
|
||||
s.local.ServeHTTP(dw, req)
|
||||
|
||||
success := dw.status < http.StatusBadRequest
|
||||
if !success {
|
||||
s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status)
|
||||
}
|
||||
event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleMetrics serves Prometheus-format performance metrics. Returns 503 when
|
||||
// performance monitoring is disabled.
|
||||
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte("# performance monitor not available\n"))
|
||||
return
|
||||
}
|
||||
s.perf.MetricsHandler().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func handleRootRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui", http.StatusFound)
|
||||
}
|
||||
|
||||
func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui/models", http.StatusFound)
|
||||
}
|
||||
|
||||
// handleUpstream proxies ANY request under /upstream/<model>/<path> directly to
|
||||
// the model's process, bypassing model dispatch by body/query inspection.
|
||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamPath := r.PathValue("upstreamPath")
|
||||
|
||||
searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
|
||||
if !found {
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect /upstream/model to /upstream/model/ so relative URLs in upstream
|
||||
// responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method.
|
||||
if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") {
|
||||
newPath := "/upstream/" + searchName + "/"
|
||||
if r.URL.RawQuery != "" {
|
||||
newPath += "?" + r.URL.RawQuery
|
||||
}
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
||||
http.Redirect(w, r, newPath, http.StatusMovedPermanently)
|
||||
} else {
|
||||
http.Redirect(w, r, newPath, http.StatusPermanentRedirect)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Strip the /upstream/<model> prefix before forwarding.
|
||||
r.URL.Path = remainingPath
|
||||
// Pin the resolved model so the router skips body/query extraction.
|
||||
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||
|
||||
// If the path matches an upstream.ignorePaths entry and the model is
|
||||
// not already loaded, refuse the request without triggering a swap. The
|
||||
// server was not able to process the response because the model was not
|
||||
// already loaded.
|
||||
for _, re := range s.cfg.Upstream.IgnorePaths {
|
||||
if !re.MatchString(remainingPath) {
|
||||
continue
|
||||
}
|
||||
if s.local.Handles(modelID) {
|
||||
state, ok := s.local.RunningModels()[modelID]
|
||||
if !ok || state != process.StateReady {
|
||||
shared.SendResponse(w, r, http.StatusConflict,
|
||||
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Either the model is already loaded (no swap would be triggered)
|
||||
// or this is a peer model (peer proxying never swaps). Fall through
|
||||
// to normal dispatch.
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.local.Handles(modelID):
|
||||
s.local.ServeHTTP(w, r)
|
||||
case s.peer.Handles(modelID):
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user