Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d567fa78cb | |||
| 187f1ae27a | |||
| 0ae56b1eb9 | |||
| e46cbeb2bf | |||
| a0578f0007 | |||
| d207a059a4 | |||
| 040ee1e284 | |||
| 82cad1b84e | |||
| 55c3678906 | |||
| 8b5a62d92a | |||
| d1e4c8ee77 | |||
| 11f8afead8 | |||
| 749819ef47 | |||
| 0ab9e74333 | |||
| b20be6dcd1 | |||
| fc24722258 | |||
| 2b087dffb1 | |||
| 746c083a87 | |||
| 8dd91e99e8 | |||
| 136dcdc25f | |||
| 767b8015fa | |||
| f0144a2361 | |||
| 32bc781326 | |||
| 316ad63f76 | |||
| e37077a963 | |||
| eff9b60434 | |||
| 9bcddad91b | |||
| a15e47922c | |||
| 0ab214d1c8 | |||
| d07b063ab6 | |||
| 826210dac9 | |||
| 6cf1317341 | |||
| 8e84b2ec4f | |||
| ed77385d08 | |||
| 92b90447e8 | |||
| 62aea0e83d | |||
| 8c660dcb90 | |||
| f6877b8175 | |||
| 9b3a33d7b9 | |||
| 0cfe5a6639 | |||
| 44e1501e81 | |||
| 46cea36bc2 | |||
| ccfba0df28 | |||
| ddfae90b19 | |||
| 29d3d9ba20 | |||
| 9be9a87fa0 | |||
| 6ea551362e | |||
| 03d58e53fa | |||
| c790d0ee03 |
@@ -15,6 +15,8 @@ reviews:
|
||||
auto_review:
|
||||
enabled: false
|
||||
drafts: false
|
||||
unit_tests:
|
||||
enabled: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
|
||||
@@ -44,13 +44,10 @@ jobs:
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 #v6.2.0
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install check-jsonschema
|
||||
run: pip install check-jsonschema
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Validate config.example.yaml against schema
|
||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||
run: go test ./internal/config/ -run TestConfig_ExampleMatchesSchema -v
|
||||
|
||||
@@ -2,10 +2,10 @@ name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# the llama.cpp daily packages have time to build and publish (~8hr after llama.cpp project's cron)
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
- cron: "00 12,18 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -19,21 +19,17 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/... ./internal/...
|
||||
staticcheck ./proxy/... ./internal/... || true
|
||||
test-dev:
|
||||
go test -short ./...
|
||||
staticcheck ./... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/... ./internal/...
|
||||
test:
|
||||
go test -short -count=1 ./internal/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/... ./internal/...
|
||||
test-all:
|
||||
go test -race -count=1 ./internal/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui-svelte && npm install
|
||||
@@ -64,7 +60,7 @@ windows: ui
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
# for testing with real external processes
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||
|
||||
@@ -88,10 +88,11 @@ Real time log streaming:
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. WinGet
|
||||
4. From release binaries
|
||||
5. From source
|
||||
2. Homebrew (macOS and Linux)
|
||||
3. MacPorts (macOS)
|
||||
4. WinGet
|
||||
5. From release binaries
|
||||
6. From source
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
@@ -155,6 +156,16 @@ brew install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
### MacPorts (macOS)
|
||||
|
||||
> [!NOTE]
|
||||
> Maintained by MacPorts community - [llama-swap port](https://ports.macports.org/port/llama-swap). It is not an official part of llama-swap.
|
||||
|
||||
```shell
|
||||
sudo port install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
### WinGet Install (Windows)
|
||||
|
||||
> [!NOTE]
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/watcher"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
mainLogger := logmon.New()
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("version: %s (%s), built at %s", version, commit, date)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("Error loading config: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(conf.Profiles) > 0 {
|
||||
mainLogger.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) {
|
||||
case "debug":
|
||||
mainLogger.SetLogLevel(logmon.LevelDebug)
|
||||
case "info":
|
||||
mainLogger.SetLogLevel(logmon.LevelInfo)
|
||||
case "warn":
|
||||
mainLogger.SetLogLevel(logmon.LevelWarn)
|
||||
case "error":
|
||||
mainLogger.SetLogLevel(logmon.LevelError)
|
||||
default:
|
||||
mainLogger.SetLogLevel(logmon.LevelInfo)
|
||||
}
|
||||
|
||||
mainLogger.Debugf("PID: %d", os.Getpid())
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
var mon *perf.Monitor
|
||||
if !conf.Performance.Disabled {
|
||||
mon, err = perf.New(conf.Performance, mainLogger)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("failed to create monitor: %s", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
mon.Start()
|
||||
} else {
|
||||
mainLogger.Info("performance monitoring is disabled")
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
|
||||
// Context that bounds the lifetime of background watcher goroutines.
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create server with initial handlergit
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloading := false
|
||||
var reloadMutex sync.Mutex
|
||||
reloadProxyManager := func() {
|
||||
reloadMutex.Lock()
|
||||
if reloading {
|
||||
reloadMutex.Unlock()
|
||||
return
|
||||
}
|
||||
reloading = true
|
||||
reloadMutex.Unlock()
|
||||
defer func() {
|
||||
reloadMutex.Lock()
|
||||
reloading = false
|
||||
reloadMutex.Unlock()
|
||||
}()
|
||||
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
mainLogger.Info("Reloading Configuration")
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Warnf("Unable to reload configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
mainLogger.Debug("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
if mon != nil {
|
||||
mon.UpdateConfig(conf.Performance)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
newPM.SetPerfMonitor(mon)
|
||||
srv.Handler = newPM
|
||||
mainLogger.Debug("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("Unable to load configuration: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
newPM.SetPerfMonitor(mon)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
|
||||
if *watchConfig {
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err)
|
||||
return
|
||||
}
|
||||
mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)")
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: func() {
|
||||
reloadProxyManager()
|
||||
},
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
// Signal handling
|
||||
go func() {
|
||||
for {
|
||||
sig := <-sigChan
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
mainLogger.Debug("Received SIGHUP")
|
||||
reloadProxyManager()
|
||||
case syscall.SIGINT, syscall.SIGTERM:
|
||||
mainLogger.Debugf("Received signal %v, shutting down...", sig)
|
||||
if mon != nil {
|
||||
mon.Stop()
|
||||
}
|
||||
watcherCancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
mainLogger.Errorf("Server shutdown: %v", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
default:
|
||||
// do nothing on other signals
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
mainLogger.Infof("llama-swap listening on http://%s", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
mainLogger.Errorf("Fatal server error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
+227
-73
@@ -82,6 +82,78 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections."
|
||||
},
|
||||
"groupsConfig": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"matrixConfig": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
@@ -306,81 +378,68 @@
|
||||
},
|
||||
"timeouts": {
|
||||
"$ref": "#/definitions/timeouts"
|
||||
},
|
||||
"capabilities": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"in": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"uniqueItems": true,
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"audio",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"description": "List of input modalities understood by the model."
|
||||
},
|
||||
"out": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"uniqueItems": true,
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"audio",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"description": "List of output modalities generated by the model."
|
||||
},
|
||||
"tools": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model supports function calling."
|
||||
},
|
||||
"reranker": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model supports the /v1/rerank endpoint."
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Maximum token context length supported by the model."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Defines what the model accepts for input, output and other metadata. Used in v1/models to inform clients what the model can do. An empty capabilities block (all zero values) is treated as not configured."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
"$ref": "#/definitions/groupsConfig"
|
||||
},
|
||||
"matrix": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
"$ref": "#/definitions/matrixConfig"
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
@@ -512,28 +571,123 @@
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
},
|
||||
"upstream": {
|
||||
"type": "object",
|
||||
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
|
||||
"properties": {
|
||||
"ignorePaths": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [
|
||||
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
|
||||
],
|
||||
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {}
|
||||
},
|
||||
"routing": {
|
||||
"type": "object",
|
||||
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||
"properties": {
|
||||
"scheduler": {
|
||||
"type": "object",
|
||||
"description": "Scheduler configuration. Decides the order in which queued requests are serviced.",
|
||||
"properties": {
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"fifo"
|
||||
],
|
||||
"default": "fifo",
|
||||
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fifo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {
|
||||
"type": "object",
|
||||
"description": "Per-model priority. Keys are model IDs, values are integers (default 0). Higher values are serviced first.",
|
||||
"additionalProperties": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"router": {
|
||||
"type": "object",
|
||||
"description": "Router configuration. Selects between the group and matrix swapping strategies.",
|
||||
"properties": {
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"group",
|
||||
"matrix"
|
||||
],
|
||||
"default": "group",
|
||||
"description": "Router to use. 'group' uses static groups, 'matrix' uses the solver-based swap matrix."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"groups": {
|
||||
"$ref": "#/definitions/groupsConfig"
|
||||
},
|
||||
"matrix": {
|
||||
"$ref": "#/definitions/matrixConfig"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"allOf": [
|
||||
{
|
||||
"if": {
|
||||
"required": ["groups"]
|
||||
"required": [
|
||||
"groups"
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["matrix"]
|
||||
"required": [
|
||||
"matrix"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"if": {
|
||||
"required": ["matrix"]
|
||||
"required": [
|
||||
"matrix"
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["groups"]
|
||||
"required": [
|
||||
"groups"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
+213
-87
@@ -134,6 +134,18 @@ apiKeys:
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# upstream: controls behaviour of the /upstream passthrough endpoint
|
||||
# - optional, default: empty dictionary
|
||||
# - recommended to only use in special use cases. Leaving it as the
|
||||
# default will typically be the best experience
|
||||
upstream:
|
||||
# ignorePaths: list of RE2 compatible regular expressions
|
||||
# - default: (see below)
|
||||
# - any request to a path matching any of the regular expressions
|
||||
# will be ignored and not trigger a swap
|
||||
ignorePaths:
|
||||
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -312,6 +324,37 @@ models:
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# capabilities: defines what the model accepts for input, output and other metadata
|
||||
# - optional; omitted or all-zero means no capabilities
|
||||
# - used in v1/models to inform clients what the model can do
|
||||
capabilities:
|
||||
# in: list of modalities understood by the model
|
||||
# - default: []
|
||||
# - valid: text, audio, image
|
||||
in:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
# out: list of modalities generated by the model
|
||||
# - default: []
|
||||
# - valid: text, audio, image
|
||||
out:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
# tools: the model supports function calling
|
||||
# - default: false
|
||||
tools: true
|
||||
|
||||
# reranker: the model supports the /v1/rerank endpoint
|
||||
# - default: false
|
||||
reranker: false
|
||||
|
||||
# context: the maximum token context length supported
|
||||
# - default: 0
|
||||
# - must be an integer > 0
|
||||
context: 32000
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -343,93 +386,6 @@ models:
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# =============================================================================
|
||||
# matrix: run concurrent models with a solver-based swap DSL
|
||||
# =============================================================================
|
||||
#
|
||||
# Matrix or Groups?
|
||||
#
|
||||
# Groups are available and fully supported. The syntax may be easier to use
|
||||
# for simple use cases.
|
||||
#
|
||||
# Documentation can be found here:
|
||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||
#
|
||||
# A config can only use a matrix (recommended) or groups. A configuration error
|
||||
# will occur if both are defined. Groups is legacy but is fully supported with
|
||||
# no plans to deprecate it.
|
||||
#
|
||||
# ~~~~~
|
||||
#
|
||||
# The matrix declares valid combinations of models that can run concurrently.
|
||||
# When a model is requested, the solver finds the cheapest way to make it
|
||||
# available by evicting as few (and least costly) running models as possible.
|
||||
#
|
||||
# Solver behavior:
|
||||
# 1. Request arrives for model X
|
||||
# 2. If X is already running, forward immediately. Done.
|
||||
# 3. Find all sets containing X
|
||||
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||
# every running model NOT in that set
|
||||
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||
# 6. Evict what needs to stop. Start X. Forward request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||
# Only the requested model is started — others are not preloaded.
|
||||
#
|
||||
# A model not appearing in any set can only run alone.
|
||||
#
|
||||
matrix:
|
||||
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||
# - required for sets and evict_costs settings
|
||||
# - each entry is a short name to a real model ID. Do not use an alias
|
||||
# - used to keep set DSL logic short and easier to read
|
||||
# - sets and evict_costs only use identifiers defined in vars
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named sets of concurrent model combinations
|
||||
# Values are DSL strings with operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline another set's expression
|
||||
#
|
||||
# Expansion examples:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → expands llms inline, then applies & v
|
||||
sets:
|
||||
# LLM + TTS: switching between g/q/m won't evict v
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# LLM + TTS + reranker
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# LLM + image generation, no TTS
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# 70B model uses all GPUs, can only run alone
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
@@ -446,6 +402,176 @@ hooks:
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# routing:
|
||||
# Controls how llama-swap decides which models can run at the same time and
|
||||
# which get swapped out. Choose one of two swap engines:
|
||||
#
|
||||
# - group: the default engine. Simpler to configure. You define groups of
|
||||
# models that run together, and loading one group typically unloads
|
||||
# the others.
|
||||
#
|
||||
# - matrix: the newer engine. More involved to configure, but far more
|
||||
# flexible. It uses a small expression language to describe which
|
||||
# model combinations are allowed to run concurrently, enabling
|
||||
# setups that groups cannot express.
|
||||
#
|
||||
# The routing section is optional.
|
||||
routing:
|
||||
router:
|
||||
# use: a string defining which engine to use
|
||||
# - optional, default: "group"
|
||||
# - valid values: group, matrix
|
||||
use: group
|
||||
|
||||
# settings: a dictionary of settings for the specific engines
|
||||
settings:
|
||||
# groups: a dictionary of named groups
|
||||
# - optional, default: empty dictionary
|
||||
# - lets you keep some models loaded while others swap out
|
||||
# - every member must be a model ID defined in the models section
|
||||
# - a model can belong to only one group
|
||||
# - behaviour is set per group with the `swap`, `exclusive` and
|
||||
# `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the model names below are illustrative and are not defined above.
|
||||
groups:
|
||||
# group1 reproduces llama-swap's default behaviour: only one model
|
||||
# runs at a time across the entire instance.
|
||||
"group1":
|
||||
# swap: how members of this group swap among themselves
|
||||
# - optional, default: true
|
||||
# - true: only one member runs at a time
|
||||
# - false: all members can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive: how this group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: running a member unloads every other group
|
||||
# - false: running a member leaves other groups untouched
|
||||
exclusive: true
|
||||
|
||||
# members: the model IDs in this group
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# group2: members all run together, but loading any other group
|
||||
# unloads them.
|
||||
"group2":
|
||||
# swap: false lets all members stay loaded at once
|
||||
swap: false
|
||||
|
||||
# exclusive: false means requesting a member loads it without
|
||||
# unloading any other group
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# forever: a persistent group that other groups can never unload.
|
||||
"forever":
|
||||
# persistent: other groups cannot unload this group's members
|
||||
# - optional, default: false
|
||||
# - has no effect on swapping within the group
|
||||
persistent: true
|
||||
|
||||
# swap/exclusive: false keeps all members loaded and avoids
|
||||
# unloading other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# The matrix lists the model combinations that are allowed to run
|
||||
# concurrently. When a model is requested, the solver makes room for it
|
||||
# by evicting as few running models as possible, preferring to keep the
|
||||
# costliest ones loaded.
|
||||
#
|
||||
# Solver behaviour:
|
||||
# 1. A request arrives for model X.
|
||||
# 2. If X is already running, forward the request. Done.
|
||||
# 3. Collect every set that contains X.
|
||||
# 4. For each set, add up the evict_costs of the running models that
|
||||
# are NOT in that set — that is the set's cost.
|
||||
# 5. Choose the lowest-cost set. Break ties by definition order.
|
||||
# 6. Evict the models outside that set, start X, forward the request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] also permits any subset of itself.
|
||||
# Only the requested model is started; the others are not preloaded.
|
||||
#
|
||||
# A model that appears in no set can only run on its own.
|
||||
#
|
||||
matrix:
|
||||
# vars: short aliases for model IDs (alphanumeric, 1-8 chars)
|
||||
# - required: sets and evict_costs reference these names, not model IDs
|
||||
# - map each short name to a real model ID (not a model alias)
|
||||
# - keeps the set expressions short and readable
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named combinations of models that may run together.
|
||||
# Each value is an expression built from these operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline the expression of another set
|
||||
#
|
||||
# Each expression expands into one or more concrete sets:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → inline the llms set, then AND with v
|
||||
sets:
|
||||
# An LLM plus TTS. Switching between g/q/m keeps v loaded.
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# An LLM plus TTS plus reranker.
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# An LLM plus image generation, no TTS.
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# The 70B model uses every GPU, so it can only run alone.
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# scheduler: how queued requests are ordered.
|
||||
# The default and only valid scheduler is "fifo"
|
||||
scheduler:
|
||||
use: fifo
|
||||
settings:
|
||||
fifo:
|
||||
# priority: a dictionary of model ID -> priority
|
||||
# - optional, default: empty dictionary
|
||||
# - models default to priority 0
|
||||
# - higher priority requests are serviced first in the queue
|
||||
priority:
|
||||
A: 10
|
||||
B: 5
|
||||
C: 5
|
||||
D: 1
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
|
||||
@@ -2,10 +2,6 @@ ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
# TARGETARCH is auto-set by `docker buildx build --platform …` (amd64/arm64);
|
||||
# falls back to amd64 when an older `docker build` runs without buildx.
|
||||
ARG TARGETARCH=amd64
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
@@ -37,9 +33,15 @@ WORKDIR /app
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz"
|
||||
set -eux; \
|
||||
case "$(uname -m)" in \
|
||||
x86_64) ARCH=amd64 ;; \
|
||||
aarch64) ARCH=arm64 ;; \
|
||||
*) echo "unsupported arch: $(uname -m)" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
curl --fail -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
|
||||
@@ -9,12 +9,14 @@ require (
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.1
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/jsonschema-go v0.4.3
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/shirou/gopsutil/v4 v4.26.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.41.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -70,7 +72,6 @@ require (
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -61,6 +61,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
)
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
+40
-611
@@ -2,16 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -85,12 +78,6 @@ type GroupConfig struct {
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
@@ -129,13 +116,16 @@ type Config struct {
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// swap matrix: solver-based alternative to groups
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
// routing is the canonical source for swap/scheduling configuration.
|
||||
// New code must read Routing, never the backwards-compat fields below.
|
||||
Routing RoutingConfig `yaml:"routing"`
|
||||
|
||||
// populated during validation when matrix is configured
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
// Groups and Matrix are permanent backwards-compat input fields for the
|
||||
// legacy top-level `groups:`/`matrix:` keys. They are normalized into
|
||||
// Routing by LoadConfigFromReader. New code must not read them directly.
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
@@ -160,6 +150,38 @@ type Config struct {
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
|
||||
// upstream controls behaviour of the /upstream passthrough endpoint
|
||||
Upstream UpstreamConfig `yaml:"upstream"`
|
||||
}
|
||||
|
||||
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||
type RoutingConfig struct {
|
||||
Scheduler SchedulerConfig `yaml:"scheduler"`
|
||||
Router RouterConfig `yaml:"router"`
|
||||
}
|
||||
|
||||
type SchedulerConfig struct {
|
||||
Use string `yaml:"use"` // default "fifo"
|
||||
Settings SchedulerSettings `yaml:"settings"`
|
||||
}
|
||||
|
||||
type SchedulerSettings struct {
|
||||
Fifo FifoConfig `yaml:"fifo"`
|
||||
}
|
||||
|
||||
type FifoConfig struct {
|
||||
Priority map[string]int `yaml:"priority"` // model ID -> priority, default 0
|
||||
}
|
||||
|
||||
type RouterConfig struct {
|
||||
Use string `yaml:"use"` // "group" (default) | "matrix"
|
||||
Settings RouterSettings `yaml:"settings"`
|
||||
}
|
||||
|
||||
type RouterSettings struct {
|
||||
Groups map[string]GroupConfig `yaml:"groups"`
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -189,369 +211,6 @@ func LoadConfig(path string) (Config, error) {
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
@@ -596,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -173,6 +173,25 @@ groups:
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expectedGroups := map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
@@ -246,22 +265,19 @@ groups:
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
Groups: expectedGroups,
|
||||
Upstream: UpstreamConfig{
|
||||
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||
},
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
Settings: RouterSettings{
|
||||
Groups: expectedGroups,
|
||||
},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestConfig_ExampleMatchesSchema validates that config.example.yaml conforms to
|
||||
// config-schema.json. Both files live at the repository root.
|
||||
func TestConfig_ExampleMatchesSchema(t *testing.T) {
|
||||
const (
|
||||
schemaPath = "../../config-schema.json"
|
||||
examplePath = "../../config.example.yaml"
|
||||
)
|
||||
|
||||
schemaBytes, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading %s: %v", schemaPath, err)
|
||||
}
|
||||
|
||||
var schema jsonschema.Schema
|
||||
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
|
||||
t.Fatalf("unmarshalling schema: %v", err)
|
||||
}
|
||||
|
||||
resolved, err := schema.Resolve(&jsonschema.ResolveOptions{
|
||||
BaseURI: "https://github.com/mostlygeek/llama-swap/",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolving schema: %v", err)
|
||||
}
|
||||
|
||||
exampleBytes, err := os.ReadFile(examplePath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading %s: %v", examplePath, err)
|
||||
}
|
||||
|
||||
// Convert YAML to a JSON-like value so numbers and keys match what the
|
||||
// validator expects.
|
||||
var yamlValue any
|
||||
if err := yaml.Unmarshal(exampleBytes, &yamlValue); err != nil {
|
||||
t.Fatalf("unmarshalling example yaml: %v", err)
|
||||
}
|
||||
jsonBytes, err := json.Marshal(yamlValue)
|
||||
if err != nil {
|
||||
t.Fatalf("converting example to json: %v", err)
|
||||
}
|
||||
var instance any
|
||||
if err := json.Unmarshal(jsonBytes, &instance); err != nil {
|
||||
t.Fatalf("unmarshalling example json: %v", err)
|
||||
}
|
||||
|
||||
if err := resolved.Validate(instance); err != nil {
|
||||
t.Fatalf("config.example.yaml does not match config-schema.json:\n%v", err)
|
||||
}
|
||||
}
|
||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "space in second key reports correct index",
|
||||
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
@@ -1544,3 +1549,174 @@ peers:
|
||||
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
// twoModels is a minimal models block reused by the routing tests below.
|
||||
const twoModels = `
|
||||
models:
|
||||
gemma:
|
||||
cmd: echo gemma
|
||||
proxy: http://localhost:8080
|
||||
qwen:
|
||||
cmd: echo qwen
|
||||
proxy: http://localhost:8081
|
||||
`
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelGroups(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
// default group injected for orphaned models (none here) still leaves g1
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: matrix
|
||||
settings:
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseGroup(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
settings:
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "migrate")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseMatrixWithoutSettings(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: matrix
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "routing.router.settings.matrix is not set")
|
||||
}
|
||||
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active.
|
||||
func TestConfig_Routing_RouterSettingsBothGroupsAndMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
settings:
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
matrix:
|
||||
sets:
|
||||
s: "gemma"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
// use: group means groups are active and matrix is ignored
|
||||
assert.Equal(t, "group", config.Routing.Router.Use)
|
||||
assert.Nil(t, config.Matrix)
|
||||
assert.Contains(t, config.Groups, "g1")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_UnknownRouter(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: bogus
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown router")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_FifoPriorityUnknownModel(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
scheduler:
|
||||
settings:
|
||||
fifo:
|
||||
priority:
|
||||
nope: 5
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown model")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_FifoPriorityKnownModel(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
scheduler:
|
||||
settings:
|
||||
fifo:
|
||||
priority:
|
||||
gemma: 5
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, cfg.Routing.Scheduler.Settings.Fifo.Priority["gemma"])
|
||||
}
|
||||
|
||||
@@ -165,6 +165,25 @@ groups:
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expectedGroups := map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
@@ -235,22 +254,19 @@ groups:
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
Groups: expectedGroups,
|
||||
Upstream: UpstreamConfig{
|
||||
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||
},
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
Settings: RouterSettings{
|
||||
Groups: expectedGroups,
|
||||
},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
// Apply default for upstream.ignorePaths when not specified. The default
|
||||
// matches common static-asset suffixes so they do not trigger a swap.
|
||||
if len(config.Upstream.IgnorePaths) == 0 {
|
||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "fifo"
|
||||
}
|
||||
if config.Routing.Scheduler.Use != "fifo" {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -15,6 +15,9 @@ type MatrixConfig struct {
|
||||
Var map[string]string `yaml:"vars"`
|
||||
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||
Sets OrderedSets `yaml:"sets"`
|
||||
|
||||
// populated by ValidateMatrix; not settable from yaml
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
}
|
||||
|
||||
// SetEntry is a single named set with its DSL expression.
|
||||
|
||||
@@ -289,7 +289,9 @@ matrix:
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Matrix)
|
||||
assert.Len(t, cfg.ExpandedSets, 2)
|
||||
assert.Len(t, cfg.Matrix.ExpandedSets, 2)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
// Groups should be empty when matrix is used
|
||||
assert.Empty(t, cfg.Groups)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// identityMapPaths is the set of dotted paths whose direct children are
|
||||
// identity-keyed maps. A child key present in two sources is a hard error;
|
||||
// such keys name discrete entities (a model, a group, a peer, etc.) and a
|
||||
// duplicate means the user has split one entity across files by mistake.
|
||||
var identityMapPaths = map[string]bool{
|
||||
"models": true,
|
||||
"groups": true,
|
||||
"profiles": true,
|
||||
"peers": true,
|
||||
"matrix": true,
|
||||
"routing.router.settings.groups": true,
|
||||
"routing.router.settings.matrix": true,
|
||||
}
|
||||
|
||||
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||
// and -config-dir (optional). At least one must be provided. The -config file
|
||||
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||
// merged in sorted filename order. The merged document is passed through the
|
||||
// existing LoadConfigFromReader pipeline unchanged.
|
||||
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||
if configPath == "" && configDir == "" {
|
||||
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||
}
|
||||
|
||||
var sourcePaths []string
|
||||
|
||||
if configPath != "" {
|
||||
sourcePaths = append(sourcePaths, configPath)
|
||||
}
|
||||
|
||||
if configDir != "" {
|
||||
dirFiles, err := listYAMLFiles(configDir)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
absConfig, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||
}
|
||||
for _, f := range dirFiles {
|
||||
absF, err := filepath.Abs(f)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||
}
|
||||
if absConfig == absF {
|
||||
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sourcePaths = append(sourcePaths, dirFiles...)
|
||||
}
|
||||
|
||||
if len(sourcePaths) == 0 {
|
||||
return Config{}, fmt.Errorf("no configuration sources found")
|
||||
}
|
||||
|
||||
var merged *yaml.Node
|
||||
for _, p := range sourcePaths {
|
||||
node, err := parseSource(p)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if node == nil {
|
||||
continue // empty file
|
||||
}
|
||||
if merged == nil {
|
||||
merged = node
|
||||
continue
|
||||
}
|
||||
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if merged == nil {
|
||||
// All sources were empty; run the pipeline on empty input so defaults
|
||||
// and validation still apply (e.g. startPort, performance defaults).
|
||||
return LoadConfigFromReader(strings.NewReader(""))
|
||||
}
|
||||
|
||||
out, err := yaml.Marshal(merged)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||
}
|
||||
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||
}
|
||||
|
||||
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||
func listYAMLFiles(dir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var files []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
files = append(files, filepath.Join(dir, name))
|
||||
}
|
||||
sort.Strings(files)
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||
// Returns a nil node (no error) when the file is empty or contains only
|
||||
// comments.
|
||||
//
|
||||
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||
func parseSource(path string) (*yaml.Node, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||
}
|
||||
yamlStr, err := substituteEnvMacros(string(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||
}
|
||||
var doc yaml.Node
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||
}
|
||||
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||
root := &doc
|
||||
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||
root = root.Content[0]
|
||||
}
|
||||
if root.Kind == 0 || root.Content == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if root.Kind != yaml.MappingNode {
|
||||
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||
// only one side are kept; shared keys are merged recursively under the rules
|
||||
// in mergeValue. srcPath is included in error messages to identify the file
|
||||
// that introduced the conflict.
|
||||
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||
srcIdx := indexMapping(src)
|
||||
|
||||
// First pass: merge shared keys in place.
|
||||
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||
keyNode := dst.Content[i]
|
||||
dstVal := dst.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
srcVal, ok := srcIdx[key]
|
||||
if !ok {
|
||||
continue // dst-only key, keep as-is
|
||||
}
|
||||
|
||||
childPath := joinPath(path, key)
|
||||
|
||||
if identityMapPaths[childPath] {
|
||||
// Identity-keyed map: each child key names a discrete entity
|
||||
// (a model, group, peer, ...). A shared child key is a hard
|
||||
// error; src-only children are appended in the second pass.
|
||||
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: append src-only keys.
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
if _, ok := dstIdx[key]; ok {
|
||||
continue // already merged above
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||
// entity and produces an error naming the conflicting key and source file.
|
||||
// src-only keys are appended to dst.
|
||||
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||
}
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
if _, dup := dstIdx[key]; dup {
|
||||
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||
// non-null side.
|
||||
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||
switch {
|
||||
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||
|
||||
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||
return nil
|
||||
|
||||
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||
if isNullScalar(dstVal) {
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
}
|
||||
if isNullScalar(srcVal) {
|
||||
return nil
|
||||
}
|
||||
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||
|
||||
case isNull(dstVal):
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
|
||||
case isNull(srcVal):
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||
}
|
||||
}
|
||||
|
||||
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||
func isNull(n *yaml.Node) bool {
|
||||
if n == nil || n.Kind == 0 {
|
||||
return true
|
||||
}
|
||||
return isNullScalar(n)
|
||||
}
|
||||
|
||||
func isNullScalar(n *yaml.Node) bool {
|
||||
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||
}
|
||||
|
||||
// indexMapping builds a key -> value-node index for a mapping node.
|
||||
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||
idx[n.Content[i].Value] = n.Content[i+1]
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func joinPath(parent, key string) string {
|
||||
if parent == "" {
|
||||
return key
|
||||
}
|
||||
return parent + "." + key
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||
// path of the written file.
|
||||
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(dir, name)
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||
return p
|
||||
}
|
||||
|
||||
// modelCfg builds a single-model YAML snippet indented for nesting under a
|
||||
// `models:` key. The proxy uses a fixed port so tests don't depend on
|
||||
// ${PORT} allocation.
|
||||
func modelCfg(id, cmd string) string {
|
||||
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||
_, err := LoadConfigSources("", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", `
|
||||
models:
|
||||
`+modelCfg("model1", "echo hi")+`
|
||||
groups:
|
||||
group1:
|
||||
members: ["model1"]
|
||||
`)
|
||||
cfg, err := LoadConfigSources(cfgPath, "")
|
||||
require.NoError(t, err)
|
||||
_, id, ok := cfg.FindConfig("model1")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "model1", id)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"alpha", "beta"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present", want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||
// -config lives outside -config-dir; both contribute models additively.
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
cfgDir := t.TempDir()
|
||||
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"base", "ext"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present after merge", want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||
// is also a member of -config-dir is rejected.
|
||||
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
|
||||
_, err := LoadConfigSources(cfgPath, dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// Both macros are available globally after merge.
|
||||
low, ok := cfg.Macros.Get("LOW")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 1, low)
|
||||
high, ok := cfg.Macros.Get("HIGH")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 2, high)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupA:
|
||||
members: [m1]
|
||||
`)
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupB:
|
||||
members: [m2]
|
||||
`)
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
groups := cfg.Routing.Router.Settings.Groups
|
||||
assert.Contains(t, groups, "groupA")
|
||||
assert.Contains(t, groups, "groupB")
|
||||
// default group added by pipeline for orphaned/leftover routing groups...
|
||||
// here both groups reference distinct models
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||
// verifies env/macro substitution runs on the merged document.
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["m1"]
|
||||
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||
// parseSource unmarshalled before env substitution, so the brace in
|
||||
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||
// Two files defining distinct models, scanned in z..a order by filename.
|
||||
// Determine merged result is the same regardless of how the FS returns them.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||
|
||||
const runs = 3
|
||||
for i := 0; i < runs; i++ {
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// startPort-based allocation: first allocated model gets 5800.
|
||||
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||
_, _, ok := cfg.FindConfig("amodel")
|
||||
assert.True(t, ok)
|
||||
_, _, ok = cfg.FindConfig("zmodel")
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgDir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Models, "m1")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||
// An empty -config-dir with no -config is an error: there is nothing to
|
||||
// load and silently producing an empty config would mask the misconfig.
|
||||
cfgDir := t.TempDir()
|
||||
_, err := LoadConfigSources("", cfgDir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||
// another — they do, because merge concats global macros before validation
|
||||
// runs. This test documents that a macro from file A is usable in file B.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["user"]
|
||||
assert.Contains(t, m.Cmd, "hello")
|
||||
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||
// File A: routing.router block absent (null on root for routing);
|
||||
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
@@ -9,6 +10,47 @@ const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
var validModalities = map[string]struct{}{
|
||||
"text": {},
|
||||
"audio": {},
|
||||
"image": {},
|
||||
}
|
||||
|
||||
// ModelCapConfig defines what modalities and features a model supports.
|
||||
// Used in /v1/models to inform clients. An empty block (all zero values) is
|
||||
// treated as not configured.
|
||||
type ModelCapConfig struct {
|
||||
In []string `yaml:"in"`
|
||||
Out []string `yaml:"out"`
|
||||
Tools bool `yaml:"tools"`
|
||||
Reranker bool `yaml:"reranker"`
|
||||
Context int `yaml:"context"`
|
||||
}
|
||||
|
||||
// Empty returns true when all fields are at their zero values.
|
||||
func (c ModelCapConfig) Empty() bool {
|
||||
return len(c.In) == 0 && len(c.Out) == 0 && !c.Tools && !c.Reranker && c.Context == 0
|
||||
}
|
||||
|
||||
// Validate checks that all modality values are recognized and context is
|
||||
// non-negative. Returns an error if any value is invalid.
|
||||
func (c ModelCapConfig) Validate() error {
|
||||
for _, m := range c.In {
|
||||
if _, ok := validModalities[m]; !ok {
|
||||
return fmt.Errorf("capabilities.in: invalid modality %q, must be one of: text, audio, image", m)
|
||||
}
|
||||
}
|
||||
for _, m := range c.Out {
|
||||
if _, ok := validModalities[m]; !ok {
|
||||
return fmt.Errorf("capabilities.out: invalid modality %q, must be one of: text, audio, image", m)
|
||||
}
|
||||
}
|
||||
if c.Context < 0 {
|
||||
return errors.New("capabilities.context: must be >= 0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
// 0 = no timeout
|
||||
type TimeoutsConfig struct {
|
||||
@@ -55,6 +97,9 @@ type ModelConfig struct {
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
|
||||
// Capabilities defines what modalities and features the model supports.
|
||||
Capabilities ModelCapConfig `yaml:"capabilities"`
|
||||
|
||||
// Copy of HealthCheckTimeout from global config
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ models:
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -170,3 +170,167 @@ models:
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
func TestConfig_ModelCapabilities(t *testing.T) {
|
||||
t.Run("all fields", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
in:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
out:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
tools: true
|
||||
context: 32000
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.In)
|
||||
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.Out)
|
||||
assert.True(t, mc.Capabilities.Tools)
|
||||
assert.Equal(t, 32000, mc.Capabilities.Context)
|
||||
})
|
||||
|
||||
t.Run("partial fields", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
tools: true
|
||||
context: 8192
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.Nil(t, mc.Capabilities.In)
|
||||
assert.Nil(t, mc.Capabilities.Out)
|
||||
assert.True(t, mc.Capabilities.Tools)
|
||||
assert.Equal(t, 8192, mc.Capabilities.Context)
|
||||
})
|
||||
|
||||
t.Run("not set", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
|
||||
t.Run("tools false is empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
tools: false
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
|
||||
t.Run("reranker true is not empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
reranker: true
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.True(t, mc.Capabilities.Reranker)
|
||||
})
|
||||
|
||||
t.Run("reranker false is empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
reranker: false
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_ModelCapabilities_Validate(t *testing.T) {
|
||||
t.Run("valid_modalities", func(t *testing.T) {
|
||||
caps := ModelCapConfig{
|
||||
In: []string{"text", "image"},
|
||||
Out: []string{"text", "audio"},
|
||||
Tools: true,
|
||||
Context: 100000,
|
||||
}
|
||||
assert.NoError(t, caps.Validate())
|
||||
})
|
||||
|
||||
t.Run("empty_is_valid", func(t *testing.T) {
|
||||
caps := ModelCapConfig{}
|
||||
assert.NoError(t, caps.Validate())
|
||||
})
|
||||
|
||||
t.Run("invalid_in_modality", func(t *testing.T) {
|
||||
caps := ModelCapConfig{In: []string{"video"}}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.in")
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
|
||||
t.Run("invalid_out_modality", func(t *testing.T) {
|
||||
caps := ModelCapConfig{Out: []string{"video"}}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.out")
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
|
||||
t.Run("negative_context", func(t *testing.T) {
|
||||
caps := ModelCapConfig{Context: -1}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.context")
|
||||
})
|
||||
|
||||
t.Run("rejects_invalid_at_load", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
in:
|
||||
- text
|
||||
- video
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// DefaultUpstreamIgnorePathsPattern is the default regular expression applied
|
||||
// to upstream.ignorePaths when the section is empty or absent from the config.
|
||||
// It matches common static-asset suffixes so requests for .js/.css/.png/etc.
|
||||
// files do not trigger a model swap.
|
||||
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`
|
||||
|
||||
// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used
|
||||
// when upstream.ignorePaths is not specified in the config. The returned slice
|
||||
// is fresh so callers may mutate it without affecting other configs.
|
||||
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
|
||||
return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)}
|
||||
}
|
||||
|
||||
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
|
||||
type UpstreamConfig struct {
|
||||
// IgnorePaths is a slice of compiled regular expressions. Any request to
|
||||
// /upstream/<model>/<path> whose remaining path matches any of these
|
||||
// expressions will be ignored and not trigger a swap. When the config
|
||||
// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
|
||||
IgnorePaths []*regexp.Regexp `yaml:"-"`
|
||||
}
|
||||
|
||||
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
|
||||
// plain strings, which are then compiled into *regexp.Regexp.
|
||||
type rawUpstreamConfig struct {
|
||||
IgnorePaths []string `yaml:"ignorePaths"`
|
||||
}
|
||||
|
||||
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
|
||||
// entry fails to compile, an error is returned.
|
||||
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
var raw rawUpstreamConfig
|
||||
if err := value.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
|
||||
for _, p := range raw.IgnorePaths {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
|
||||
}
|
||||
patterns = append(patterns, re)
|
||||
}
|
||||
u.IgnorePaths = patterns
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const upstreamConfigHeader = `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
|
||||
// When upstream is not specified at all, the default pattern is applied.
|
||||
content := upstreamConfigHeader
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||
|
||||
def := cfg.Upstream.IgnorePaths[0]
|
||||
assert.IsType(t, ®exp.Regexp{}, def)
|
||||
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
|
||||
|
||||
// The default matches common static-asset suffixes.
|
||||
assert.True(t, def.MatchString("/foo.js"))
|
||||
assert.True(t, def.MatchString("/bar/baz.json"))
|
||||
assert.True(t, def.MatchString("/static/img.png"))
|
||||
assert.True(t, def.MatchString("/notes.txt"))
|
||||
assert.True(t, def.MatchString("/favicon.ico"))
|
||||
// And does not match inference API paths.
|
||||
assert.False(t, def.MatchString("/v1/chat/completions"))
|
||||
assert.False(t, def.MatchString("/v1/models"))
|
||||
assert.False(t, def.MatchString("/health"))
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
|
||||
// When upstream is present but ignorePaths is omitted, the default is still
|
||||
// applied.
|
||||
content := `upstream: {}` + "\n" + upstreamConfigHeader
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
|
||||
content := `
|
||||
upstream:
|
||||
ignorePaths:
|
||||
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
|
||||
- "^/static/.*"
|
||||
` + upstreamConfigHeader
|
||||
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 2)
|
||||
|
||||
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
|
||||
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
|
||||
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
|
||||
|
||||
// Confirm the type is *regexp.Regexp to satisfy the API contract.
|
||||
for _, re := range cfg.Upstream.IgnorePaths {
|
||||
assert.IsType(t, ®exp.Regexp{}, re)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
|
||||
content := `
|
||||
upstream:
|
||||
ignorePaths:
|
||||
- "[invalid("
|
||||
` + upstreamConfigHeader
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "upstream.ignorePaths")
|
||||
assert.Contains(t, err.Error(), "invalid regular expression")
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package perf
|
||||
|
||||
type LUID struct {
|
||||
LowPart uint32
|
||||
HighPart int32
|
||||
}
|
||||
|
||||
const maxEnumAdapters = 16
|
||||
|
||||
type D3DKMT_ENUMADAPTERS2 struct {
|
||||
NumAdapters uint32
|
||||
pAdapters uintptr
|
||||
}
|
||||
|
||||
type D3DKMT_ADAPTERINFO struct {
|
||||
hAdapter uint32
|
||||
AdapterLuid LUID
|
||||
NumOfSources uint32
|
||||
bPresentMoveRegionsPreferred int32
|
||||
}
|
||||
|
||||
type D3DKMT_OPENADAPTERFROMLUID struct {
|
||||
AdapterLuid LUID
|
||||
hAdapter uint32
|
||||
}
|
||||
|
||||
type D3DKMT_CLOSEADAPTER struct {
|
||||
hAdapter uint32
|
||||
}
|
||||
|
||||
type KMTQUERYADAPTERINFOTYPE int32
|
||||
|
||||
const (
|
||||
KMTQAITYPE_UMDRIVERPRIVATE KMTQUERYADAPTERINFOTYPE = 0
|
||||
KMTQAITYPE_ADAPTERREGISTRYINFO KMTQUERYADAPTERINFOTYPE = 8
|
||||
KMTQAITYPE_DRIVERVERSION KMTQUERYADAPTERINFOTYPE = 13
|
||||
KMTQAITYPE_PHYSICALADAPTERDEVICEIDS KMTQUERYADAPTERINFOTYPE = 31
|
||||
KMTQAITYPE_NODEPERFDATA KMTQUERYADAPTERINFOTYPE = 61
|
||||
KMTQAITYPE_ADAPTERPERFDATA KMTQUERYADAPTERINFOTYPE = 62
|
||||
KMTQAITYPE_ADAPTERPERFDATA_CAPS KMTQUERYADAPTERINFOTYPE = 63
|
||||
)
|
||||
|
||||
type D3DKMT_QUERYADAPTERINFO struct {
|
||||
hAdapter uint32
|
||||
Type KMTQUERYADAPTERINFOTYPE
|
||||
pPrivateDriverData uintptr
|
||||
PrivateDriverDataSize uint32
|
||||
}
|
||||
|
||||
type D3DKMT_ADAPTER_PERFDATA struct {
|
||||
PhysicalAdapterIndex uint32
|
||||
MemoryFrequency uint64
|
||||
MaxMemoryFrequency uint64
|
||||
MaxMemoryFrequencyOC uint64
|
||||
MemoryBandwidth uint64
|
||||
PCIEBandwidth uint64
|
||||
FanRPM uint32
|
||||
Power uint32
|
||||
Temperature uint32
|
||||
PowerStateOverride byte
|
||||
}
|
||||
|
||||
type D3DKMT_QUERYSTATISTICS_TYPE int32
|
||||
|
||||
const (
|
||||
D3DKMT_QUERYSTATISTICS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 0
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS D3DKMT_QUERYSTATISTICS_TYPE = 1
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 2
|
||||
D3DKMT_QUERYSTATISTICS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 3
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 4
|
||||
D3DKMT_QUERYSTATISTICS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 5
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 6
|
||||
D3DKMT_QUERYSTATISTICS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 7
|
||||
D3DKMT_QUERYSTATISTICS_PROCESS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 8
|
||||
)
|
||||
|
||||
type D3DKMT_ADAPTER_PERFDATACAPS struct {
|
||||
PhysicalAdapterIndex uint32
|
||||
MaxMemoryBandwidth uint64
|
||||
MaxPCIEBandwidth uint64
|
||||
MaxFanRPM uint32
|
||||
TemperatureMax uint32
|
||||
TemperatureWarning uint32
|
||||
}
|
||||
|
||||
type D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT struct {
|
||||
SegmentId uint32
|
||||
}
|
||||
|
||||
type D3DKMT_QUERYSTATISTICS_QUERY_NODE struct {
|
||||
NodeId uint32
|
||||
}
|
||||
@@ -0,0 +1,529 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
d3dkmDLL *windows.LazyDLL
|
||||
procEnumAdapters2 *windows.LazyProc
|
||||
procOpenAdapterFromLuid *windows.LazyProc
|
||||
procCloseAdapter *windows.LazyProc
|
||||
procQueryAdapterInfo *windows.LazyProc
|
||||
procQueryStatistics *windows.LazyProc
|
||||
d3dkmtInitOnce sync.Once
|
||||
d3dkmtInitErr error
|
||||
)
|
||||
|
||||
// initD3DKMT lazily loads gdi32.dll and resolves D3DKMT function pointers.
|
||||
// Safe for concurrent use via sync.Once.
|
||||
func initD3DKMT() error {
|
||||
d3dkmtInitOnce.Do(func() {
|
||||
d3dkmDLL = windows.NewLazySystemDLL("gdi32.dll")
|
||||
|
||||
procEnumAdapters2 = d3dkmDLL.NewProc("D3DKMTEnumAdapters2")
|
||||
procOpenAdapterFromLuid = d3dkmDLL.NewProc("D3DKMTOpenAdapterFromLuid")
|
||||
procCloseAdapter = d3dkmDLL.NewProc("D3DKMTCloseAdapter")
|
||||
procQueryAdapterInfo = d3dkmDLL.NewProc("D3DKMTQueryAdapterInfo")
|
||||
procQueryStatistics = d3dkmDLL.NewProc("D3DKMTQueryStatistics")
|
||||
|
||||
for name, p := range map[string]*windows.LazyProc{
|
||||
"D3DKMTEnumAdapters2": procEnumAdapters2,
|
||||
"D3DKMTOpenAdapterFromLuid": procOpenAdapterFromLuid,
|
||||
"D3DKMTCloseAdapter": procCloseAdapter,
|
||||
"D3DKMTQueryAdapterInfo": procQueryAdapterInfo,
|
||||
"D3DKMTQueryStatistics": procQueryStatistics,
|
||||
} {
|
||||
if err := p.Find(); err != nil {
|
||||
d3dkmtInitErr = fmt.Errorf("D3DKMT %s not found: %w", name, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return d3dkmtInitErr
|
||||
}
|
||||
|
||||
// ntstatusCall invokes a D3DKMT function and returns a non-nil error if the
|
||||
// NTSTATUS result is not STATUS_SUCCESS (0).
|
||||
func ntstatusCall(proc *windows.LazyProc, arg unsafe.Pointer) error {
|
||||
ret, _, _ := proc.Call(uintptr(arg))
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("NTSTATUS 0x%08x", uint32(ret))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// d3dkmEnumerateAdapters enumerates all available graphics adapters via
|
||||
// D3DKMTEnumAdapters2.
|
||||
func d3dkmEnumerateAdapters() ([]D3DKMT_ADAPTERINFO, error) {
|
||||
var adapters [maxEnumAdapters]D3DKMT_ADAPTERINFO
|
||||
enum := D3DKMT_ENUMADAPTERS2{
|
||||
NumAdapters: maxEnumAdapters,
|
||||
pAdapters: uintptr(unsafe.Pointer(&adapters[0])),
|
||||
}
|
||||
if err := ntstatusCall(procEnumAdapters2, unsafe.Pointer(&enum)); err != nil {
|
||||
return nil, fmt.Errorf("EnumAdapters2: %w", err)
|
||||
}
|
||||
if enum.NumAdapters == 0 {
|
||||
return nil, fmt.Errorf("no adapters found")
|
||||
}
|
||||
result := make([]D3DKMT_ADAPTERINFO, enum.NumAdapters)
|
||||
for i := uint32(0); i < enum.NumAdapters; i++ {
|
||||
result[i] = adapters[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// d3dkmOpenAdapter opens a D3DKMT adapter handle for the given LUID.
|
||||
func d3dkmOpenAdapter(luid LUID) (uint32, error) {
|
||||
req := D3DKMT_OPENADAPTERFROMLUID{
|
||||
AdapterLuid: luid,
|
||||
}
|
||||
if err := ntstatusCall(procOpenAdapterFromLuid, unsafe.Pointer(&req)); err != nil {
|
||||
return 0, fmt.Errorf("OpenAdapterFromLuid: %w", err)
|
||||
}
|
||||
return req.hAdapter, nil
|
||||
}
|
||||
|
||||
// d3dkmCloseAdapter closes a previously opened D3DKMT adapter handle.
|
||||
func d3dkmCloseAdapter(hAdapter uint32) error {
|
||||
req := D3DKMT_CLOSEADAPTER{hAdapter: hAdapter}
|
||||
return ntstatusCall(procCloseAdapter, unsafe.Pointer(&req))
|
||||
}
|
||||
|
||||
// d3dkmGetAdapterPerfData queries per-adapter performance data (temperature,
|
||||
// fan RPM, power, bandwidth) via KMTQAITYPE_ADAPTERPERFDATA.
|
||||
func d3dkmGetAdapterPerfData(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATA, error) {
|
||||
var data D3DKMT_ADAPTER_PERFDATA
|
||||
req := D3DKMT_QUERYADAPTERINFO{
|
||||
hAdapter: hAdapter,
|
||||
Type: KMTQAITYPE_ADAPTERPERFDATA,
|
||||
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
|
||||
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
|
||||
}
|
||||
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
|
||||
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATA): %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// d3dkmGetAdapterPerfDataCaps queries static adapter performance capabilities
|
||||
// (max fan RPM, temperature limits, max bandwidth) via KMTQAITYPE_ADAPTERPERFDATA_CAPS.
|
||||
func d3dkmGetAdapterPerfDataCaps(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATACAPS, error) {
|
||||
var data D3DKMT_ADAPTER_PERFDATACAPS
|
||||
req := D3DKMT_QUERYADAPTERINFO{
|
||||
hAdapter: hAdapter,
|
||||
Type: KMTQAITYPE_ADAPTERPERFDATA_CAPS,
|
||||
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
|
||||
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
|
||||
}
|
||||
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
|
||||
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATACAPS): %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
type queryStatsBuffer struct {
|
||||
Type int32 // offset 0
|
||||
AdapterLuid LUID // offset 4
|
||||
hProcess uintptr // offset 16
|
||||
// _result mirrors the D3DKMT_QUERYSTATISTICS_RESULT union.
|
||||
// sizeof(D3DKMT_QUERYSTATISTICS) == 0x328 (808 bytes) on x64.
|
||||
//
|
||||
// The C struct layout (x64):
|
||||
// offset 0: Type (int32, 4 bytes)
|
||||
// offset 4: AdapterLuid (LUID, 8 bytes)
|
||||
// offset 12: 4 bytes padding (for 8-byte alignment of hProcess)
|
||||
// offset 16: hProcess (HANDLE, 8 bytes)
|
||||
// offset 24: QueryResult (union, 780 bytes — largest member is AdapterInformation)
|
||||
// offset 804: anonymous input union (QueryNode.NodeId / QuerySegment.SegmentId, 4 bytes)
|
||||
//
|
||||
// Previous bug: _result was [776]byte, placing QueryId at offset 800 instead of 804.
|
||||
// The kernel read NodeId/SegmentId from offset 804 (always zero from _pad),
|
||||
// causing all NODE and SEGMENT queries to use index 0 regardless of the value
|
||||
// passed in QueryId. This produced alternating behavior where only GPU util OR
|
||||
// memory util appeared to work, depending on which test variant happened to put
|
||||
// non-zero data near offset 804 in the result buffer.
|
||||
_result [780]byte // offset 24, size 780 — places QueryId at offset 804
|
||||
QueryId int32 // offset 804 — matches C anonymous union for NodeId/SegmentId
|
||||
}
|
||||
|
||||
func init() {
|
||||
var buf queryStatsBuffer
|
||||
if unsafe.Sizeof(buf) != 808 {
|
||||
panic(fmt.Sprintf("queryStatsBuffer size %d != expected 808 (sizeof D3DKMT_QUERYSTATISTICS on x64)", unsafe.Sizeof(buf)))
|
||||
}
|
||||
if unsafe.Offsetof(buf.QueryId) != 804 {
|
||||
panic(fmt.Sprintf("queryStatsBuffer.QueryId offset %d != expected 804 (C anonymous union offset)", unsafe.Offsetof(buf.QueryId)))
|
||||
}
|
||||
|
||||
var perfData D3DKMT_ADAPTER_PERFDATA
|
||||
if unsafe.Sizeof(perfData) != 64 {
|
||||
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATA size %d != expected 64 on x64", unsafe.Sizeof(perfData)))
|
||||
}
|
||||
|
||||
var caps D3DKMT_ADAPTER_PERFDATACAPS
|
||||
if unsafe.Sizeof(caps) != 40 {
|
||||
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATACAPS size %d != expected 40 on x64", unsafe.Sizeof(caps)))
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
qsoffsetNbSegments = 0
|
||||
qsoffsetNodeCount = 4
|
||||
qsoffsetCommitLimit = 0
|
||||
qsoffsetBytesCommitted = 8
|
||||
qsoffsetBytesResident = 16
|
||||
qsoffsetRunningTime = 0
|
||||
qsoffsetSystemRunningTime = 272
|
||||
)
|
||||
|
||||
// d3dkmQueryAdapterStats returns the number of memory segments and compute
|
||||
// nodes for the adapter identified by luid.
|
||||
func d3dkmQueryAdapterStats(luid LUID) (nbSegments uint32, nodeCount uint32, err error) {
|
||||
buf := queryStatsBuffer{
|
||||
Type: int32(D3DKMT_QUERYSTATISTICS_ADAPTER),
|
||||
AdapterLuid: luid,
|
||||
}
|
||||
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||
return 0, 0, fmt.Errorf("QueryStatistics(ADAPTER): %w", err)
|
||||
}
|
||||
nbSegments = binary.LittleEndian.Uint32(buf._result[qsoffsetNbSegments : qsoffsetNbSegments+4])
|
||||
nodeCount = binary.LittleEndian.Uint32(buf._result[qsoffsetNodeCount : qsoffsetNodeCount+4])
|
||||
return nbSegments, nodeCount, nil
|
||||
}
|
||||
|
||||
// d3dkmQuerySegmentStats returns the commit limit (total) and resident
|
||||
// (used) bytes for the given memory segment of an adapter.
|
||||
func d3dkmQuerySegmentStats(luid LUID, segmentID uint32) (commitLimit uint64, bytesResident uint64, err error) {
|
||||
buf := queryStatsBuffer{
|
||||
Type: int32(D3DKMT_QUERYSTATISTICS_SEGMENT),
|
||||
AdapterLuid: luid,
|
||||
QueryId: int32(segmentID),
|
||||
}
|
||||
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||
return 0, 0, fmt.Errorf("QueryStatistics(SEGMENT %d): %w", segmentID, err)
|
||||
}
|
||||
commitLimit = binary.LittleEndian.Uint64(buf._result[qsoffsetCommitLimit : qsoffsetCommitLimit+8])
|
||||
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesResident : qsoffsetBytesResident+8])
|
||||
if bytesResident == 0 {
|
||||
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesCommitted : qsoffsetBytesCommitted+8])
|
||||
}
|
||||
return commitLimit, bytesResident, nil
|
||||
}
|
||||
|
||||
// d3dkmQueryNodeStats returns the global and system running time counters
|
||||
// (in 100ns units) for the given compute node of an adapter.
|
||||
func d3dkmQueryNodeStats(luid LUID, nodeID uint32) (runningTime uint64, systemRunningTime uint64, err error) {
|
||||
buf := queryStatsBuffer{
|
||||
Type: int32(D3DKMT_QUERYSTATISTICS_NODE),
|
||||
AdapterLuid: luid,
|
||||
QueryId: int32(nodeID),
|
||||
}
|
||||
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
|
||||
return 0, 0, fmt.Errorf("QueryStatistics(NODE %d): %w", nodeID, err)
|
||||
}
|
||||
runningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetRunningTime : qsoffsetRunningTime+8])
|
||||
systemRunningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetSystemRunningTime : qsoffsetSystemRunningTime+8])
|
||||
return runningTime, systemRunningTime, nil
|
||||
}
|
||||
|
||||
type nodeRunningTimes struct {
|
||||
Global uint64
|
||||
System uint64
|
||||
}
|
||||
|
||||
// d3dkmtNodeUtil computes GPU node utilization as a percentage from running
|
||||
// time deltas. Returns -1 if counters went backwards (wrap/reset), 0 if idle.
|
||||
func d3dkmtNodeUtil(prevRT, curRT nodeRunningTimes, elapsed100ns int64) float64 {
|
||||
if curRT.Global < prevRT.Global || curRT.System < prevRT.System {
|
||||
return -1
|
||||
}
|
||||
gd := curRT.Global - prevRT.Global
|
||||
sd := curRT.System - prevRT.System
|
||||
|
||||
if gd > 0 && sd > 0 {
|
||||
util := float64(gd) / float64(sd)
|
||||
if util > 1.0 {
|
||||
util = 1.0
|
||||
}
|
||||
return util * 100.0
|
||||
} else if gd > 0 && elapsed100ns > 0 {
|
||||
util := float64(gd) / float64(elapsed100ns) * 100.0
|
||||
if util > 100.0 {
|
||||
util = 100.0
|
||||
}
|
||||
return util
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// d3dkmtFanPct returns fan speed as a percentage of maxFanRPM, clamped to
|
||||
// 100%. Returns 0 if maxFanRPM is unavailable or fan is not spinning.
|
||||
func d3dkmtFanPct(fanRPM, maxFanRPM uint32) float64 {
|
||||
if maxFanRPM > 0 && fanRPM > 0 {
|
||||
pct := float64(fanRPM) / float64(maxFanRPM) * 100.0
|
||||
if pct > 100.0 {
|
||||
pct = 100.0
|
||||
}
|
||||
return pct
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// d3dkmtPowerW converts power from deci-watts (as reported by D3DKMT) to
|
||||
// watts. Returns 0 if the power value is zero.
|
||||
func d3dkmtPowerW(power uint32) float64 {
|
||||
if power > 0 {
|
||||
return float64(power) / 10.0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// d3dkmtTempC converts temperature from deci-Celsius (as reported by D3DKMT)
|
||||
// to degrees Celsius.
|
||||
func d3dkmtTempC(tempDeciC uint32) int {
|
||||
return int(tempDeciC / 10)
|
||||
}
|
||||
|
||||
type d3dkmtAdapterState struct {
|
||||
luid LUID
|
||||
hAdapter uint32
|
||||
nbSegments uint32
|
||||
nodeCount uint32
|
||||
maxFanRPM uint32
|
||||
prevNodeRT map[uint32]nodeRunningTimes
|
||||
prevTime time.Time
|
||||
}
|
||||
|
||||
// tryD3DKMT attempts to start GPU monitoring using D3DKMT and optional PDH
|
||||
// counters. It returns a channel of GpuStat snapshots or an error if no
|
||||
// usable adapters are found.
|
||||
func tryD3DKMT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if err := initD3DKMT(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
adapterInfos, err := d3dkmEnumerateAdapters()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type adapterMeta struct {
|
||||
luid LUID
|
||||
nbSegments uint32
|
||||
nodeCount uint32
|
||||
maxFanRPM uint32
|
||||
}
|
||||
|
||||
var metaList []adapterMeta
|
||||
|
||||
for i, ai := range adapterInfos {
|
||||
hAdapter, err := d3dkmOpenAdapter(ai.AdapterLuid)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d: open failed: %s", i, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
nbSegments, nodeCount, err := d3dkmQueryAdapterStats(ai.AdapterLuid)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d: query stats failed: %s", i, err.Error())
|
||||
d3dkmCloseAdapter(hAdapter)
|
||||
continue
|
||||
}
|
||||
|
||||
caps, err := d3dkmGetAdapterPerfDataCaps(hAdapter)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d: perf caps failed: %s", i, err.Error())
|
||||
}
|
||||
|
||||
d3dkmCloseAdapter(hAdapter)
|
||||
|
||||
var maxFanRPM uint32
|
||||
if caps != nil {
|
||||
maxFanRPM = caps.MaxFanRPM
|
||||
}
|
||||
|
||||
metaList = append(metaList, adapterMeta{
|
||||
luid: ai.AdapterLuid,
|
||||
nbSegments: nbSegments,
|
||||
nodeCount: nodeCount,
|
||||
maxFanRPM: maxFanRPM,
|
||||
})
|
||||
logger.Debugf("adapter %d: segments=%d nodes=%d fan_max=%d luid=%d:%d", i, nbSegments, nodeCount, maxFanRPM, ai.AdapterLuid.HighPart, ai.AdapterLuid.LowPart)
|
||||
}
|
||||
|
||||
if len(metaList) == 0 {
|
||||
return nil, fmt.Errorf("no usable D3DKMT adapters found")
|
||||
}
|
||||
|
||||
pdhUtil, pdhErr := initPdhGpuUtil()
|
||||
if pdhErr != nil {
|
||||
logger.Debugf("PDH GPU utilization not available: %s", pdhErr.Error())
|
||||
} else {
|
||||
logger.Info("using PDH performance counters for GPU utilization")
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if pdhUtil != nil {
|
||||
defer pdhUtil.close()
|
||||
}
|
||||
|
||||
var adapters []d3dkmtAdapterState
|
||||
for _, m := range metaList {
|
||||
hAdapter, err := d3dkmOpenAdapter(m.luid)
|
||||
if err != nil {
|
||||
logger.Debugf("reopen adapter failed: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
adapters = append(adapters, d3dkmtAdapterState{
|
||||
luid: m.luid,
|
||||
hAdapter: hAdapter,
|
||||
nbSegments: m.nbSegments,
|
||||
nodeCount: m.nodeCount,
|
||||
maxFanRPM: m.maxFanRPM,
|
||||
prevNodeRT: make(map[uint32]nodeRunningTimes),
|
||||
})
|
||||
}
|
||||
|
||||
if len(adapters) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, a := range adapters {
|
||||
d3dkmCloseAdapter(a.hAdapter)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := range adapters {
|
||||
a := &adapters[i]
|
||||
for node := uint32(0); node < a.nodeCount; node++ {
|
||||
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
|
||||
}
|
||||
a.prevTime = time.Now()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(every)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
stats := make([]GpuStat, 0, len(adapters))
|
||||
now := time.Now()
|
||||
|
||||
var pdhUtilMap map[LUID]float64
|
||||
if pdhUtil != nil {
|
||||
pdhUtilMap = pdhUtil.collect()
|
||||
}
|
||||
|
||||
for i := range adapters {
|
||||
a := &adapters[i]
|
||||
|
||||
perfData, err := d3dkmGetAdapterPerfData(a.hAdapter)
|
||||
if err != nil {
|
||||
logger.Debugf("adapter %d perfdata: %s", i, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
var memUsedMB, memTotalMB int
|
||||
for seg := uint32(0); seg < a.nbSegments; seg++ {
|
||||
limit, resident, err := d3dkmQuerySegmentStats(a.luid, seg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
memUsedMB += int(resident / (1024 * 1024))
|
||||
memTotalMB += int(limit / (1024 * 1024))
|
||||
}
|
||||
|
||||
var gpuUtil float64
|
||||
pdhGaveValue := false
|
||||
if pdhUtilMap != nil {
|
||||
if util, ok := pdhUtilMap[a.luid]; ok {
|
||||
gpuUtil = util
|
||||
pdhGaveValue = true
|
||||
}
|
||||
}
|
||||
|
||||
if !pdhGaveValue && a.nodeCount > 0 {
|
||||
elapsedNs := now.Sub(a.prevTime).Nanoseconds()
|
||||
elapsed100ns := elapsedNs / 100
|
||||
|
||||
for node := uint32(0); node < a.nodeCount; node++ {
|
||||
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if prevRT, ok := a.prevNodeRT[node]; ok {
|
||||
if globalRT < prevRT.Global || systemRT < prevRT.System {
|
||||
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
|
||||
continue
|
||||
}
|
||||
nodeUtil := d3dkmtNodeUtil(prevRT, nodeRunningTimes{Global: globalRT, System: systemRT}, elapsed100ns)
|
||||
if nodeUtil > gpuUtil {
|
||||
gpuUtil = nodeUtil
|
||||
}
|
||||
}
|
||||
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
|
||||
}
|
||||
|
||||
a.prevTime = now
|
||||
}
|
||||
|
||||
tempC := d3dkmtTempC(perfData.Temperature)
|
||||
|
||||
fanSpeedPct := d3dkmtFanPct(perfData.FanRPM, a.maxFanRPM)
|
||||
powerDrawW := d3dkmtPowerW(perfData.Power)
|
||||
|
||||
var memUtilPct float64
|
||||
if memTotalMB > 0 {
|
||||
memUtilPct = float64(memUsedMB) / float64(memTotalMB) * 100.0
|
||||
}
|
||||
|
||||
stats = append(stats, GpuStat{
|
||||
Timestamp: now,
|
||||
ID: i,
|
||||
Name: fmt.Sprintf("GPU %d", i),
|
||||
TempC: tempC,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtilPct,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
FanSpeedPct: fanSpeedPct,
|
||||
PowerDrawW: powerDrawW,
|
||||
})
|
||||
}
|
||||
|
||||
if len(stats) > 0 {
|
||||
select {
|
||||
case ch <- stats:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestD3dkmtNodeUtil_FullLoad(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 5000, System: 14000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 100.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_PartialUtil(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 3000, System: 14000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 50.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_Identical(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 10000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 20000, System: 20000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 100.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_CounterWrap(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 9000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, -1.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_SystemWrap(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 9000}
|
||||
cur := nodeRunningTimes{Global: 5000, System: 1000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, -1.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_ZeroDelta(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 100000)
|
||||
assert.Equal(t, 0.0, got)
|
||||
}
|
||||
|
||||
func TestD3dkmtNodeUtil_ElapsedFallback(t *testing.T) {
|
||||
prev := nodeRunningTimes{Global: 1000, System: 10000}
|
||||
cur := nodeRunningTimes{Global: 6000, System: 10000}
|
||||
got := d3dkmtNodeUtil(prev, cur, 50000)
|
||||
assert.InDelta(t, 10.0, got, 0.01)
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_Normal(t *testing.T) {
|
||||
assert.Equal(t, 50.0, d3dkmtFanPct(1500, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_MaxFan(t *testing.T) {
|
||||
assert.Equal(t, 100.0, d3dkmtFanPct(3000, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_OverMaxClamped(t *testing.T) {
|
||||
assert.Equal(t, 100.0, d3dkmtFanPct(4000, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_ZeroMaxFan(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtFanPct(1500, 0))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_ZeroFanRPM(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtFanPct(0, 3000))
|
||||
}
|
||||
|
||||
func TestD3dkmtFanPct_BothZero(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtFanPct(0, 0))
|
||||
}
|
||||
|
||||
func TestD3dkmtPowerW(t *testing.T) {
|
||||
assert.Equal(t, 250.0, d3dkmtPowerW(2500))
|
||||
}
|
||||
|
||||
func TestD3dkmtPowerW_Zero(t *testing.T) {
|
||||
assert.Equal(t, 0.0, d3dkmtPowerW(0))
|
||||
}
|
||||
|
||||
func TestD3dkmtTempC(t *testing.T) {
|
||||
assert.Equal(t, 65, d3dkmtTempC(650))
|
||||
}
|
||||
|
||||
func TestD3dkmtTempC_Zero(t *testing.T) {
|
||||
assert.Equal(t, 0, d3dkmtTempC(0))
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -43,3 +47,168 @@ func ParseNvidiaSmiLine(line string) *GpuStat {
|
||||
PowerDrawW: powerDraw,
|
||||
}
|
||||
}
|
||||
|
||||
// mactopOutput maps the subset of mactop's headless JSON output that is
|
||||
// relevant to GpuStat. Note that mactop's memory object is whole-system memory,
|
||||
// not GPU-attributed; the darwin monitor overlays ioreg's GPU-attributed
|
||||
// unified memory (see overlayIoregMem) so both backends report consistent
|
||||
// memory figures.
|
||||
type mactopOutput struct {
|
||||
SocMetrics struct {
|
||||
GPUPower float64 `json:"gpu_power"`
|
||||
GPUFreq int `json:"gpu_freq_mhz"`
|
||||
GPUTemp float64 `json:"gpu_temp"`
|
||||
} `json:"soc_metrics"`
|
||||
Memory struct {
|
||||
Total uint64 `json:"total"`
|
||||
Used uint64 `json:"used"`
|
||||
} `json:"memory"`
|
||||
GPUUsage float64 `json:"gpu_usage"`
|
||||
SystemInfo struct {
|
||||
Name string `json:"name"`
|
||||
GPUCoreCount int `json:"gpu_core_count"`
|
||||
} `json:"system_info"`
|
||||
Fans []struct {
|
||||
RPM int `json:"rpm"`
|
||||
MinRPM int `json:"min_rpm"`
|
||||
MaxRPM int `json:"max_rpm"`
|
||||
} `json:"fans"`
|
||||
Temperatures []struct {
|
||||
Group string `json:"group"`
|
||||
Avg float64 `json:"avg_celsius"`
|
||||
} `json:"temperatures"`
|
||||
}
|
||||
|
||||
// ioreg output uses ` = ` (with spaces) for top-level device properties and
|
||||
// `=` (no spaces) for values inside nested dictionaries such as
|
||||
// PerformanceStatistics.
|
||||
var (
|
||||
reIoregModel = regexp.MustCompile(`"model"\s*=\s*"([^"]+)"`)
|
||||
reIoregCoreCount = regexp.MustCompile(`"gpu-core-count"\s*=\s*(\d+)`)
|
||||
reIoregUtil = regexp.MustCompile(`"Device Utilization %"=(\d+)`)
|
||||
reIoregMemUsed = regexp.MustCompile(`"In use system memory"=(\d+)`)
|
||||
)
|
||||
|
||||
// ParseIoregOutput parses `ioreg -r -c IOGPU -d 1 -f` output into a GpuStat for
|
||||
// the Apple Silicon integrated GPU. This is a fallback for when mactop is not
|
||||
// installed: utilization and used memory are available, but power, temperature,
|
||||
// and fan speed are not exposed by ioreg. memTotalMB is the unified memory size
|
||||
// supplied by the caller, since Apple Silicon shares memory between CPU and GPU.
|
||||
// Returns nil if no GPU device is found in the output.
|
||||
func ParseIoregOutput(out []byte, memTotalMB int) *GpuStat {
|
||||
utilMatch := reIoregUtil.FindSubmatch(out)
|
||||
memMatch := reIoregMemUsed.FindSubmatch(out)
|
||||
if utilMatch == nil && memMatch == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var gpuUtil float64
|
||||
if utilMatch != nil {
|
||||
gpuUtil, _ = strconv.ParseFloat(string(utilMatch[1]), 64)
|
||||
}
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
var memUsedMB int
|
||||
if memMatch != nil {
|
||||
memUsedBytes, _ := strconv.ParseInt(string(memMatch[1]), 10, 64)
|
||||
memUsedMB = int(memUsedBytes / toMB)
|
||||
}
|
||||
|
||||
var memUtil float64
|
||||
if memTotalMB > 0 {
|
||||
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||
}
|
||||
|
||||
name := "Apple GPU"
|
||||
if m := reIoregModel.FindSubmatch(out); m != nil {
|
||||
name = string(m[1])
|
||||
}
|
||||
if m := reIoregCoreCount.FindSubmatch(out); m != nil {
|
||||
if cores, err := strconv.Atoi(string(m[1])); err == nil && cores > 0 {
|
||||
name = fmt.Sprintf("%s (%d-core GPU)", name, cores)
|
||||
}
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: 0,
|
||||
Name: name,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseMactopLine parses a single line of mactop headless JSON output into a
|
||||
// GpuStat for the Apple Silicon integrated GPU. Returns nil if the line cannot
|
||||
// be parsed.
|
||||
func ParseMactopLine(line string) *GpuStat {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var out mactopOutput
|
||||
if err := json.Unmarshal([]byte(line), &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
memUsedMB := int(out.Memory.Used / toMB)
|
||||
memTotalMB := int(out.Memory.Total / toMB)
|
||||
|
||||
var memUtil float64
|
||||
if memTotalMB > 0 {
|
||||
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||
}
|
||||
|
||||
name := out.SystemInfo.Name
|
||||
if name == "" {
|
||||
name = "Apple GPU"
|
||||
}
|
||||
if out.SystemInfo.GPUCoreCount > 0 {
|
||||
name = fmt.Sprintf("%s (%d-core GPU)", name, out.SystemInfo.GPUCoreCount)
|
||||
}
|
||||
|
||||
// Unified memory has no dedicated VRAM sensor; use the memory temperature
|
||||
// group when mactop exposes it.
|
||||
var vramTempC int
|
||||
for _, t := range out.Temperatures {
|
||||
if strings.EqualFold(t.Group, "Memory") {
|
||||
vramTempC = int(math.Round(t.Avg))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Average fan load across all fans as a percentage of their RPM range.
|
||||
var fanSpeed float64
|
||||
var fanCount int
|
||||
for _, f := range out.Fans {
|
||||
if f.MaxRPM > f.MinRPM {
|
||||
pct := float64(f.RPM-f.MinRPM) / float64(f.MaxRPM-f.MinRPM) * 100
|
||||
if pct < 0 {
|
||||
pct = 0
|
||||
}
|
||||
fanSpeed += pct
|
||||
fanCount++
|
||||
}
|
||||
}
|
||||
if fanCount > 0 {
|
||||
fanSpeed /= float64(fanCount)
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: 0,
|
||||
Name: name,
|
||||
TempC: int(math.Round(out.SocMetrics.GPUTemp)),
|
||||
VramTempC: vramTempC,
|
||||
GpuUtilPct: out.GPUUsage,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
FanSpeedPct: fanSpeed,
|
||||
PowerDrawW: out.SocMetrics.GPUPower,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("Not Implemented")
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -11,7 +15,156 @@ import (
|
||||
)
|
||||
|
||||
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
return nil, ErrNotImplemented
|
||||
if ch, err := tryMactop(ctx, every, logger); err == nil {
|
||||
logger.Info("using mactop for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("mactop: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := tryIoreg(ctx, every, logger); err == nil {
|
||||
logger.Info("using ioreg for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("ioreg: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// tryIoreg polls `ioreg -r -c IOGPU -d 1 -f` for Apple Silicon GPU stats. It is
|
||||
// a fallback for when mactop is not installed. ioreg exposes GPU utilization and
|
||||
// used memory but not power, temperature, or fan speed.
|
||||
func tryIoreg(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("ioreg"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// Verify ioreg actually reports a GPU device before committing to it, so we
|
||||
// can fall through to ErrNoGpuTool otherwise.
|
||||
if stat := sampleIoreg(ctx); stat == nil {
|
||||
return nil, fmt.Errorf("ioreg reported no GPU device")
|
||||
}
|
||||
|
||||
if every < time.Second {
|
||||
every = time.Second
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ticker := time.NewTicker(every)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
stat := sampleIoreg(ctx)
|
||||
if stat == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// sampleIoreg runs ioreg once and parses a single GpuStat, or returns nil.
|
||||
func sampleIoreg(ctx context.Context) *GpuStat {
|
||||
out, err := exec.CommandContext(ctx, "ioreg", "-r", "-c", "IOGPU", "-d", "1", "-f").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var memTotalMB int
|
||||
if vmStat, err := mem.VirtualMemory(); err == nil {
|
||||
memTotalMB = int(vmStat.Total / (1024 * 1024))
|
||||
}
|
||||
|
||||
return ParseIoregOutput(out, memTotalMB)
|
||||
}
|
||||
|
||||
// overlayIoregMem replaces a GpuStat's memory fields with the GPU-attributed
|
||||
// unified memory reported by ioreg. mactop only exposes whole-system memory, so
|
||||
// without this the mactop and ioreg backends would report different memory
|
||||
// semantics. It is a no-op when ioreg is unavailable or reports no GPU memory,
|
||||
// leaving the mactop-supplied values in place.
|
||||
func overlayIoregMem(ctx context.Context, stat *GpuStat) {
|
||||
ioStat := sampleIoreg(ctx)
|
||||
if ioStat == nil {
|
||||
return
|
||||
}
|
||||
stat.MemUsedMB = ioStat.MemUsedMB
|
||||
stat.MemTotalMB = ioStat.MemTotalMB
|
||||
stat.MemUtilPct = ioStat.MemUtilPct
|
||||
}
|
||||
|
||||
// tryMactop streams Apple Silicon GPU stats from mactop's headless mode.
|
||||
// See https://github.com/metaspartan/mactop. mactop emits one JSON object per
|
||||
// sample to stdout, which we parse into GpuStat.
|
||||
func tryMactop(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("mactop"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// mactop samples power over the interval, so give it at least a second.
|
||||
intervalMs := int(every.Milliseconds())
|
||||
if intervalMs < 1000 {
|
||||
intervalMs = 1000
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "mactop",
|
||||
"--headless",
|
||||
"--format", "json",
|
||||
"--interval", fmt.Sprintf("%d", intervalMs),
|
||||
)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mactop stdout pipe failed: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("mactop start failed: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
// mactop's JSON objects can be large; allow generous line lengths.
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stat := ParseMactopLine(line)
|
||||
if stat != nil {
|
||||
// mactop only reports whole-system memory; overlay ioreg's
|
||||
// GPU-attributed unified memory so both backends are consistent.
|
||||
overlayIoregMem(ctx, stat)
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd.Wait()
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func readSysStats() (SysStat, error) {
|
||||
|
||||
@@ -264,3 +264,50 @@ func TestParseNvidiaSmiLine_ZeroMemoryTotal(t *testing.T) {
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||
}
|
||||
|
||||
const ioregSample = `+-o AGXAcceleratorG13X <class AGXAcceleratorG13X, id 0x1000009a1, registered, matched, active, busy 0 (39191 ms), retain 108>
|
||||
{
|
||||
"model" = "Apple M1 Pro"
|
||||
"gpu-core-count" = 16
|
||||
"PerformanceStatistics" = {"In use system memory (driver)"=0,"Alloc system memory"=14511046656,"Tiler Utilization %"=34,"recoveryCount"=0,"Renderer Utilization %"=34,"Device Utilization %"=34,"In use system memory"=7688503296}
|
||||
"IOClass" = "AGXAcceleratorG13X"
|
||||
}`
|
||||
|
||||
func TestParseIoregOutput_ValidOutput(t *testing.T) {
|
||||
const memTotalMB = 32768
|
||||
|
||||
stat := ParseIoregOutput([]byte(ioregSample), memTotalMB)
|
||||
require.NotNil(t, stat)
|
||||
|
||||
assert.Equal(t, 0, stat.ID)
|
||||
assert.Equal(t, "Apple M1 Pro (16-core GPU)", stat.Name)
|
||||
assert.Equal(t, 34.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 7688503296/(1024*1024), stat.MemUsedMB)
|
||||
assert.Equal(t, memTotalMB, stat.MemTotalMB)
|
||||
assert.InDelta(t, float64(stat.MemUsedMB)/memTotalMB*100, stat.MemUtilPct, 0.01)
|
||||
// Not exposed by ioreg.
|
||||
assert.Equal(t, 0, stat.TempC)
|
||||
assert.Equal(t, 0.0, stat.PowerDrawW)
|
||||
assert.Equal(t, 0.0, stat.FanSpeedPct)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_NoGpuDevice(t *testing.T) {
|
||||
stat := ParseIoregOutput([]byte("no gpu here"), 32768)
|
||||
assert.Nil(t, stat)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_ZeroMemTotal(t *testing.T) {
|
||||
stat := ParseIoregOutput([]byte(ioregSample), 0)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_MissingModel(t *testing.T) {
|
||||
const out = `"Device Utilization %"=50,"In use system memory"=1048576`
|
||||
|
||||
stat := ParseIoregOutput([]byte(out), 1024)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, "Apple GPU", stat.Name)
|
||||
assert.Equal(t, 50.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 1, stat.MemUsedMB)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,13 @@ func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monito
|
||||
logger.Debugf("nvidia-smi: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := tryD3DKMT(ctx, every, logger); err == nil {
|
||||
logger.Info("using D3DKMT for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("D3DKMT: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
pdhDLL = windows.NewLazySystemDLL("pdh.dll")
|
||||
procPdhOpenQuery = pdhDLL.NewProc("PdhOpenQueryW")
|
||||
procPdhAddEnglishCounter = pdhDLL.NewProc("PdhAddEnglishCounterW")
|
||||
procPdhCollectQueryData = pdhDLL.NewProc("PdhCollectQueryData")
|
||||
procPdhGetFormattedCounterArray = pdhDLL.NewProc("PdhGetFormattedCounterArrayW")
|
||||
procPdhCloseQuery = pdhDLL.NewProc("PdhCloseQuery")
|
||||
)
|
||||
|
||||
const (
|
||||
pdhFmtDouble = 0x00000200
|
||||
pdhMoreData = 0x800007D2
|
||||
pdhNoData = 0x800007D5
|
||||
)
|
||||
|
||||
type pdhCounterValue struct {
|
||||
CStatus uint32
|
||||
DblVal float64
|
||||
}
|
||||
|
||||
type pdhCounterValueItem struct {
|
||||
SzName *uint16
|
||||
FmtValue pdhCounterValue
|
||||
}
|
||||
|
||||
func init() {
|
||||
var item pdhCounterValueItem
|
||||
if unsafe.Sizeof(item) != 24 {
|
||||
panic(fmt.Sprintf("pdhCounterValueItem size %d != expected 24 on x64", unsafe.Sizeof(item)))
|
||||
}
|
||||
}
|
||||
|
||||
type pdhGpuUtil struct {
|
||||
query uintptr
|
||||
counter uintptr
|
||||
}
|
||||
|
||||
// initPdhGpuUtil creates a PDH query for the GPU Engine utilization counter.
|
||||
// Returns nil with an error if PDH or the counter is unavailable.
|
||||
func initPdhGpuUtil() (*pdhGpuUtil, error) {
|
||||
var query uintptr
|
||||
if ret, _, _ := procPdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query))); ret != 0 {
|
||||
return nil, fmt.Errorf("PdhOpenQuery: 0x%x", ret)
|
||||
}
|
||||
|
||||
path, _ := windows.UTF16PtrFromString(`\GPU Engine(*)\Utilization Percentage`)
|
||||
var counter uintptr
|
||||
if ret, _, _ := procPdhAddEnglishCounter.Call(
|
||||
query, uintptr(unsafe.Pointer(path)), 0, uintptr(unsafe.Pointer(&counter)),
|
||||
); ret != 0 {
|
||||
procPdhCloseQuery.Call(query)
|
||||
return nil, fmt.Errorf("PdhAddEnglishCounter(GPU Engine): 0x%x", ret)
|
||||
}
|
||||
|
||||
procPdhCollectQueryData.Call(query)
|
||||
|
||||
return &pdhGpuUtil{query: query, counter: counter}, nil
|
||||
}
|
||||
|
||||
// close releases the PDH query handle.
|
||||
func (p *pdhGpuUtil) close() {
|
||||
if p.query != 0 {
|
||||
procPdhCloseQuery.Call(p.query)
|
||||
p.query = 0
|
||||
}
|
||||
}
|
||||
|
||||
// collect reads the PDH counter and returns a map of adapter LUID to
|
||||
// aggregated GPU utilization percentage, summed across all engine instances
|
||||
// per adapter and clamped to 100%.
|
||||
func (p *pdhGpuUtil) collect() map[LUID]float64 {
|
||||
ret, _, _ := procPdhCollectQueryData.Call(p.query)
|
||||
if ret != 0 && ret != pdhNoData {
|
||||
return nil
|
||||
}
|
||||
|
||||
var bufSize uint32
|
||||
var itemCount uint32
|
||||
ret, _, _ = procPdhGetFormattedCounterArray.Call(
|
||||
p.counter, pdhFmtDouble,
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
uintptr(unsafe.Pointer(&itemCount)),
|
||||
0,
|
||||
)
|
||||
if ret != pdhMoreData || itemCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := make([]byte, bufSize)
|
||||
ret, _, _ = procPdhGetFormattedCounterArray.Call(
|
||||
p.counter, pdhFmtDouble,
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
uintptr(unsafe.Pointer(&itemCount)),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
)
|
||||
if ret != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
itemSize := uint32(unsafe.Sizeof(pdhCounterValueItem{}))
|
||||
result := make(map[LUID]float64)
|
||||
|
||||
for i := uint32(0); i < itemCount; i++ {
|
||||
item := (*pdhCounterValueItem)(unsafe.Pointer(&buf[i*itemSize]))
|
||||
if item.FmtValue.CStatus != 0 {
|
||||
continue
|
||||
}
|
||||
luid, ok := parsePdhLuid(windows.UTF16PtrToString(item.SzName))
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
result[luid] += item.FmtValue.DblVal
|
||||
}
|
||||
|
||||
for luid := range result {
|
||||
if result[luid] > 100.0 {
|
||||
result[luid] = 100.0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parsePdhLuid extracts the adapter LUID (high and low parts) from a PDH
|
||||
// GPU Engine instance name (e.g. "pid_1234_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute").
|
||||
func parsePdhLuid(name string) (LUID, bool) {
|
||||
idx := strings.Index(name, "luid_0x")
|
||||
if idx < 0 {
|
||||
return LUID{}, false
|
||||
}
|
||||
rest := name[idx+7:]
|
||||
parts := strings.SplitN(rest, "_", 4)
|
||||
if len(parts) < 3 {
|
||||
return LUID{}, false
|
||||
}
|
||||
hp, err := strconv.ParseUint(parts[0], 16, 32)
|
||||
if err != nil {
|
||||
return LUID{}, false
|
||||
}
|
||||
lpStr := strings.TrimPrefix(parts[1], "0x")
|
||||
lp, err := strconv.ParseUint(lpStr, 16, 32)
|
||||
if err != nil {
|
||||
return LUID{}, false
|
||||
}
|
||||
return LUID{LowPart: uint32(lp), HighPart: int32(hp)}, true
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
//go:build windows
|
||||
|
||||
package perf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParsePdhLuid_Valid(t *testing.T) {
|
||||
name := `pid_25312_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute`
|
||||
got, ok := parsePdhLuid(name)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint32(0x000148BF), got.LowPart)
|
||||
assert.Equal(t, int32(0x00000000), got.HighPart)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_ValidNvidia(t *testing.T) {
|
||||
name := `pid_1388_luid_0x00000000_0x00011372_phys_0_eng_8_engtype_Compute_1`
|
||||
got, ok := parsePdhLuid(name)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint32(0x00011372), got.LowPart)
|
||||
assert.Equal(t, int32(0x00000000), got.HighPart)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_NonZeroHighPart(t *testing.T) {
|
||||
name := `pid_1234_luid_0x00000001_0x0000C85A_phys_0_eng_5_engtype_Copy`
|
||||
got, ok := parsePdhLuid(name)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint32(0x0000C85A), got.LowPart)
|
||||
assert.Equal(t, int32(0x00000001), got.HighPart)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_InvalidNoLuid(t *testing.T) {
|
||||
_, ok := parsePdhLuid("invalid_string_without_luid")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_InvalidEmpty(t *testing.T) {
|
||||
_, ok := parsePdhLuid("")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_InvalidHex(t *testing.T) {
|
||||
_, ok := parsePdhLuid("pid_1234_luid_0xZZZZ_0xGGGG_phys_0")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParsePdhLuid_ShortAfterLuid(t *testing.T) {
|
||||
_, ok := parsePdhLuid("pid_1234_luid_0x00000000")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
@@ -22,6 +21,22 @@ import (
|
||||
|
||||
var ErrStartAborted = fmt.Errorf("aborted")
|
||||
|
||||
// cmdWaitDelay is the upper bound the runtime will wait for child I/O to
|
||||
// drain after the process exits before force-closing the stdout/stderr
|
||||
// pipes. Required so that cmd.Wait() returns even when a forked grandchild
|
||||
// inherits and holds the pipes open (e.g. a shell wrapper that backgrounds
|
||||
// the real binary). killProcess sends the stop signal directly (not via the
|
||||
// cmd context), so this delay is measured from process exit rather than from
|
||||
// the stop request, and stays independent of the caller's graceful timeout.
|
||||
const cmdWaitDelay = 10 * time.Second
|
||||
|
||||
// parentCancelGraceTimeout is the graceful timeout used when the process is
|
||||
// torn down because parentCtx was cancelled (final router teardown or app
|
||||
// shutdown). In the normal flow the process has already been stopped via
|
||||
// Stop() by this point, so killProcess is a no-op kill; the short grace just
|
||||
// bounds the rare case where a process is still alive when its context is cut.
|
||||
const parentCancelGraceTimeout = time.Second
|
||||
|
||||
type runReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
@@ -39,6 +54,7 @@ type waitReadyReq struct {
|
||||
type startResult struct {
|
||||
cmd *exec.Cmd
|
||||
cmdDone chan struct{}
|
||||
cancel context.CancelFunc
|
||||
handlerFn http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
@@ -51,6 +67,11 @@ type ProcessCommand struct {
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
// waitDelay is assigned to cmd.WaitDelay when starting the upstream
|
||||
// process. Defaults to cmdWaitDelay; tests override it to keep the
|
||||
// pipe-close backstop from dominating their runtime.
|
||||
waitDelay time.Duration
|
||||
|
||||
runCh chan runReq
|
||||
stopCh chan stopReq
|
||||
waitReadyCh chan waitReadyReq
|
||||
@@ -85,6 +106,7 @@ func New(
|
||||
runCh: make(chan runReq),
|
||||
stopCh: make(chan stopReq),
|
||||
waitReadyCh: make(chan waitReadyReq),
|
||||
waitDelay: cmdWaitDelay,
|
||||
}
|
||||
p.state.Store(StateStopped)
|
||||
|
||||
@@ -122,6 +144,7 @@ func (p *ProcessCommand) run() {
|
||||
var (
|
||||
cmd *exec.Cmd
|
||||
cmdDone <-chan struct{}
|
||||
cmdCancel context.CancelFunc
|
||||
readyWaiters []waitReadyReq
|
||||
// runResp parks the in-flight Run caller's response channel. The
|
||||
// interface contract is that Run blocks until the process is
|
||||
@@ -164,9 +187,10 @@ func (p *ProcessCommand) run() {
|
||||
setState(StateShutdown)
|
||||
if cmd != nil {
|
||||
p.handler.Store(nil)
|
||||
p.killProcess(cmd, cmdDone, 100*time.Millisecond)
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, parentCancelGraceTimeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
@@ -177,8 +201,12 @@ func (p *ProcessCommand) run() {
|
||||
// cmdDone is nil while no process is running, so this case is
|
||||
// dormant outside of StateReady.
|
||||
case <-cmdDone:
|
||||
if cmdCancel != nil {
|
||||
cmdCancel()
|
||||
}
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
p.handler.Store(nil)
|
||||
setState(StateStopped)
|
||||
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
|
||||
@@ -226,6 +254,7 @@ func (p *ProcessCommand) run() {
|
||||
if res.err == nil {
|
||||
cmd = res.cmd
|
||||
cmdDone = res.cmdDone
|
||||
cmdCancel = res.cancel
|
||||
fn := res.handlerFn
|
||||
p.handler.Store(&fn)
|
||||
setState(StateReady)
|
||||
@@ -273,7 +302,7 @@ func (p *ProcessCommand) run() {
|
||||
cancelStart()
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, stop.timeout)
|
||||
p.killProcess(res.cmd, res.cancel, res.cmdDone, stop.timeout)
|
||||
}
|
||||
setState(StateStopped)
|
||||
notifyWaiters(ErrStartAborted)
|
||||
@@ -293,7 +322,7 @@ func (p *ProcessCommand) run() {
|
||||
setState(StateShutdown)
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, 100*time.Millisecond)
|
||||
p.killProcess(res.cmd, res.cancel, res.cmdDone, parentCancelGraceTimeout)
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
@@ -310,9 +339,10 @@ func (p *ProcessCommand) run() {
|
||||
case stop := <-p.stopCh:
|
||||
if cmd != nil {
|
||||
setState(StateStopping)
|
||||
p.killProcess(cmd, cmdDone, stop.timeout)
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, stop.timeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
p.handler.Store(nil)
|
||||
}
|
||||
// Stop is a no-op (and not an error) when already Stopped — this
|
||||
@@ -377,46 +407,71 @@ func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout ti
|
||||
reverseProxy.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
// cmdCtx + cmd.Cancel are wired as a safety net: if the context is ever
|
||||
// cancelled while the process is alive, cmd.Cancel sends SIGTERM / CmdStop
|
||||
// and the runtime escalates to SIGKILL after cmd.WaitDelay. In the normal
|
||||
// teardown path killProcess sends the stop signal directly instead, so
|
||||
// cmd.WaitDelay only acts as the inherited-pipe backstop measured from
|
||||
// process exit (see killProcess).
|
||||
cmdCtx, cmdCancel := context.WithCancel(context.Background())
|
||||
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
|
||||
cmd.Stderr = p.processLogger
|
||||
cmd.Stdout = p.processLogger
|
||||
cmd.Env = append(cmd.Environ(), p.config.Env...)
|
||||
cmd.Cancel = func() error { return p.sendStopSignal(cmd) }
|
||||
cmd.WaitDelay = p.waitDelay
|
||||
setProcAttributes(cmd)
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
if err := cmd.Start(); err != nil {
|
||||
cmdCancel()
|
||||
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
|
||||
}
|
||||
|
||||
go func() {
|
||||
waitErr := cmd.Wait()
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else if waitErr != nil {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
} else {
|
||||
switch st := p.State(); {
|
||||
case waitErr == nil:
|
||||
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
|
||||
case st == StateStopping || st == StateShutdown:
|
||||
// Expected: we force-terminated the process. A forced kill exits
|
||||
// the child with a non-zero code (e.g. taskkill /f on Windows
|
||||
// yields exit status 1), so this is not an error.
|
||||
p.proxyLogger.Debugf("<%s> process stopped by llama-swap: %v", p.id, waitErr)
|
||||
default:
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
}
|
||||
}
|
||||
close(cmdDone)
|
||||
}()
|
||||
|
||||
abort := func(err error) startResult {
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, 5*time.Second)
|
||||
return startResult{err: err}
|
||||
}
|
||||
prematureExit := func() startResult {
|
||||
cmdCancel()
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
}
|
||||
|
||||
if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
if checkEndpoint == "none" {
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
// Wait 250ms for the command to start up before health checking
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
}
|
||||
|
||||
@@ -424,16 +479,14 @@ func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout ti
|
||||
for {
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
return prematureExit()
|
||||
default:
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: fmt.Errorf("health check timed out after %v", healthCheckTimeout)}
|
||||
return abort(fmt.Errorf("health check timed out after %v", healthCheckTimeout))
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
|
||||
@@ -445,42 +498,99 @@ func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout ti
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
|
||||
break
|
||||
} else if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
return prematureExit()
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
// sendStopSignal runs the configured CmdStop (if any) or sends SIGTERM to
|
||||
// the upstream process. Wired up as cmd.Cancel so it fires whenever the
|
||||
// cmd's context is cancelled.
|
||||
func (p *ProcessCommand) sendStopSignal(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() called with nil cmd or process, nothing to stop", p.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
if p.config.CmdStop != "" {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() using CmdStop %q for pid %d", p.id, p.config.CmdStop, pid)
|
||||
stopArgs, err := config.SanitizeCommand(
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", cmd.Process.Pid)),
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", pid)),
|
||||
)
|
||||
if err == nil {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() running stop command: %s", p.id, strings.Join(stopArgs, " "))
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Env = cmd.Env
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Run()
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
runErr := stopCmd.Run()
|
||||
if runErr != nil {
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() stop command failed: %v", p.id, runErr)
|
||||
} else {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() stop command completed for pid %d", p.id, pid)
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
// fall through to SIGTERM if sanitize failed
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() failed to sanitize CmdStop %q: %v, falling back to terminateProcessTree", p.id, p.config.CmdStop, err)
|
||||
}
|
||||
// On Unix this SIGTERMs the whole process group so a forked grandchild
|
||||
// (e.g. a shell wrapper that backgrounds the real binary) is taken down
|
||||
// with the parent rather than orphaned.
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() no CmdStop configured, calling terminateProcessTree for pid %d", p.id, pid)
|
||||
termErr := terminateProcessTree(cmd)
|
||||
if termErr != nil {
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() terminateProcessTree failed for pid %d: %v", p.id, pid, termErr)
|
||||
}
|
||||
return termErr
|
||||
}
|
||||
|
||||
// killProcess terminates the upstream process. The flow:
|
||||
//
|
||||
// 1. Send the graceful stop signal (CmdStop / SIGTERM) directly — NOT by
|
||||
// cancelling cmdCtx. Cancelling the context would start cmd.WaitDelay
|
||||
// immediately, which force-kills the process WaitDelay after the signal
|
||||
// and would silently cap gracefulTimeout at WaitDelay whenever
|
||||
// gracefulTimeout is the longer of the two.
|
||||
// 2. We wait up to gracefulTimeout for the process to exit on its own.
|
||||
// 3. If still alive, we SIGKILL the process group directly (Unix) so any
|
||||
// forked descendant is force-terminated alongside the parent.
|
||||
// 4. We wait on cmdDone. cmd.WaitDelay (set when the cmd was built) is the
|
||||
// critical backstop here: once the process exits, if a forked grandchild
|
||||
// inherited the stdout/stderr pipes and is still holding them, the runtime
|
||||
// force-closes the pipes WaitDelay after the exit and cmd.Wait() unblocks.
|
||||
// Because we never cancelled the context, that WaitDelay timer measures
|
||||
// from process exit (see os/exec awaitGoroutines), not from this call.
|
||||
// Without WaitDelay this select would hang forever (the v219 bug).
|
||||
//
|
||||
// cancel() is still invoked (deferred) to release the context, but only after
|
||||
// the process has exited and os/exec's ctx watcher has already torn down, so it
|
||||
// never re-fires cmd.Cancel.
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cancel context.CancelFunc, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
// Deliver CmdStop / SIGTERM in a goroutine so a slow or hanging CmdStop
|
||||
// cannot block the run() goroutine; the gracefulTimeout + Process.Kill
|
||||
// path below still guarantees teardown.
|
||||
if cmd != nil {
|
||||
go func() {
|
||||
p.proxyLogger.Debugf("[%s] sending stop signal with timeout %v", p.id, gracefulTimeout)
|
||||
if err := p.sendStopSignal(cmd); err != nil {
|
||||
p.proxyLogger.Warnf("[%s] stop signal failed: %v", p.id, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(gracefulTimeout)
|
||||
@@ -488,10 +598,16 @@ func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gra
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
return
|
||||
case <-timer.C:
|
||||
cmd.Process.Kill()
|
||||
<-cmdDone
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
// SIGKILL the whole process group on Unix so any descendant that
|
||||
// ignored or outlived the graceful signal is force-terminated too.
|
||||
_ = killProcessTree(cmd)
|
||||
}
|
||||
<-cmdDone
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ID() string {
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
// TestProcessCommand_StopForkingWrapper is a regression for the bug reported
|
||||
// against v219 where Stop would hang indefinitely when the upstream command
|
||||
// is a shell wrapper that forks the real binary (e.g. `#!/bin/bash` then
|
||||
// `"$@"`). After SIGTERM the wrapper dies but the grandchild inherits the
|
||||
// stdout/stderr pipes; cmd.Wait() blocks waiting for the pipe-copy goroutine
|
||||
// to drain EOF, which never happens while the grandchild holds the fds.
|
||||
//
|
||||
// The fix is cmd.WaitDelay (combined with exec.CommandContext + cmd.Cancel),
|
||||
// which causes the runtime to force-close the pipes after the delay so
|
||||
// cmd.Wait() — and therefore Stop — returns.
|
||||
func TestProcessCommand_StopForkingWrapper(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
port := getFreePort(t)
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "child.pid")
|
||||
|
||||
// Wrapper script: backgrounds the child (which inherits stdout/stderr),
|
||||
// records its PID for cleanup, then waits. When SIGTERM hits bash it
|
||||
// dies without forwarding the signal; the grandchild keeps running and
|
||||
// keeps the inherited pipe fds open. This is the scenario reported in
|
||||
// the v219 regression.
|
||||
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||
simpleResponderPath, port, pidFile)
|
||||
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: wrapper,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
// Shrink the pipe-close backstop so the test doesn't sit at the
|
||||
// production default (10s). Must be set before Run() so doStart picks
|
||||
// it up when building the cmd.
|
||||
const testWaitDelay = 250 * time.Millisecond
|
||||
p.waitDelay = testWaitDelay
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Stop must return within a bounded time even though the grandchild
|
||||
// is still holding the pipe open. Budget is generous on top of
|
||||
// testWaitDelay to absorb scheduling jitter on slow CI runners; the
|
||||
// pre-fix behaviour was an unbounded hang, so any reasonable cap
|
||||
// distinguishes pass from fail.
|
||||
stopReturned := make(chan error, 1)
|
||||
stopStart := time.Now()
|
||||
go func() { stopReturned <- p.Stop(testStopTimeout) }()
|
||||
|
||||
const stopBudget = testWaitDelay + 2*time.Second
|
||||
select {
|
||||
case err := <-stopReturned:
|
||||
if err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
t.Logf("Stop returned in %v", time.Since(stopStart))
|
||||
case <-time.After(stopBudget):
|
||||
t.Fatalf("Stop did not return within %v — cmd.Wait() likely hung on inherited pipe", stopBudget)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopHonorsGracefulTimeout is a regression for the bug
|
||||
// where cmd.WaitDelay capped the graceful shutdown window. killProcess used to
|
||||
// cancel the cmd context to deliver SIGTERM, which starts cmd.WaitDelay
|
||||
// immediately; a process whose SIGTERM handler needs longer than WaitDelay to
|
||||
// finish was force-killed early even though Stop was given a much longer
|
||||
// timeout. The fix sends the signal directly so WaitDelay measures from process
|
||||
// exit (its inherited-pipe backstop role), leaving the graceful window to the
|
||||
// caller's Stop timeout.
|
||||
func TestProcessCommand_StopHonorsGracefulTimeout(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
marker := filepath.Join(dir, "graceful.done")
|
||||
ready := filepath.Join(dir, "trap.ready")
|
||||
|
||||
// On SIGTERM, sleep past the (short) WaitDelay, then write the marker and
|
||||
// exit cleanly. If WaitDelay still drove the kill, bash would be SIGKILLed
|
||||
// mid-handler and the marker would never be written. The ready file is
|
||||
// written only after the trap is installed so the test does not race
|
||||
// SIGTERM ahead of it (CheckEndpoint:none marks ready before bash runs).
|
||||
script := filepath.Join(dir, "graceful.sh")
|
||||
body := fmt.Sprintf(
|
||||
"#!/bin/bash\ncleanup() { sleep 0.6; echo done > %q; exit 0; }\ntrap cleanup SIGTERM\necho ready > %q\nwhile true; do sleep 0.1; done\n",
|
||||
marker, ready,
|
||||
)
|
||||
if err := os.WriteFile(script, []byte(body), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: script,
|
||||
Proxy: "http://127.0.0.1:1", // unused: health check disabled
|
||||
CheckEndpoint: "none",
|
||||
})
|
||||
// WaitDelay shorter than the handler's 0.6s sleep, and far shorter than the
|
||||
// Stop timeout below — this is the window the old code mis-killed in.
|
||||
p.waitDelay = 200 * time.Millisecond
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Wait until the trap is installed before stopping.
|
||||
trapDeadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if _, err := os.Stat(ready); err == nil {
|
||||
break
|
||||
}
|
||||
if time.Now().After(trapDeadline) {
|
||||
t.Fatalf("script did not install SIGTERM trap in time")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
stopStart := time.Now()
|
||||
if err := p.Stop(5 * time.Second); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
elapsed := time.Since(stopStart)
|
||||
|
||||
// The handler must have run to completion (marker written) rather than
|
||||
// being force-killed at waitDelay.
|
||||
if _, err := os.Stat(marker); err != nil {
|
||||
t.Fatalf("graceful handler did not complete (marker missing): %v", err)
|
||||
}
|
||||
// And Stop must have waited for the handler (>~0.6s), not returned at the
|
||||
// 200ms waitDelay.
|
||||
if elapsed < 500*time.Millisecond {
|
||||
t.Fatalf("Stop returned in %v — process was killed before its graceful handler finished", elapsed)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopReapsForkedGrandchild verifies that stopping a forking
|
||||
// wrapper takes down the backgrounded grandchild too, rather than leaving it as
|
||||
// an orphan. The fix is Setpgid (runtime_unix.go): the wrapper leads its own
|
||||
// process group, so the stop signal is delivered to the whole group via the
|
||||
// negative PID and reaches the grandchild the wrapper never reaped.
|
||||
func TestProcessCommand_StopReapsForkedGrandchild(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
port := getFreePort(t)
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "child.pid")
|
||||
|
||||
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||
simpleResponderPath, port, pidFile)
|
||||
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: wrapper,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Read the grandchild PID the wrapper recorded.
|
||||
var childPID int
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
data, err := os.ReadFile(pidFile)
|
||||
if err == nil {
|
||||
if pid, perr := strconv.Atoi(strings.TrimSpace(string(data))); perr == nil && pid > 0 {
|
||||
childPID = pid
|
||||
break
|
||||
}
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("wrapper did not record grandchild PID")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
|
||||
// After Stop the grandchild must be gone. Signal 0 probes liveness without
|
||||
// actually sending a signal; give it a brief window to exit after the
|
||||
// group SIGTERM.
|
||||
proc, err := os.FindProcess(childPID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess: %v", err)
|
||||
}
|
||||
gone := false
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := proc.Signal(syscall.Signal(0)); err != nil {
|
||||
gone = true
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if !gone {
|
||||
t.Errorf("grandchild PID %d still alive after Stop — process group was not reaped", childPID)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// killChildFromPidFile reads a PID written by the wrapper script and SIGKILLs
|
||||
// it so leaked orphans don't accumulate between test runs. Best-effort.
|
||||
func killChildFromPidFile(pidFile string) {
|
||||
data, err := os.ReadFile(pidFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
if err != nil || pid <= 0 {
|
||||
return
|
||||
}
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = proc.Kill()
|
||||
}
|
||||
@@ -4,9 +4,41 @@ package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
// setProcAttributes starts the upstream in its own process group (Setpgid) so
|
||||
// the entire process tree can be signalled at once via its negative PID. This
|
||||
// is what lets us reap a forked grandchild — e.g. a shell wrapper that
|
||||
// backgrounds the real binary and exits — instead of leaking it as an orphan
|
||||
// that holds the inherited stdout/stderr pipes open.
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
// terminateProcessTree sends SIGTERM to the whole process group led by the
|
||||
// command, giving every process in the tree a chance to shut down gracefully.
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
return signalProcessTree(cmd, syscall.SIGTERM)
|
||||
}
|
||||
|
||||
// killProcessTree sends SIGKILL to the whole process group, force-terminating
|
||||
// every process in the tree.
|
||||
func killProcessTree(cmd *exec.Cmd) error {
|
||||
return signalProcessTree(cmd, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
// signalProcessTree signals the process group led by cmd.Process. Because the
|
||||
// child was started with Setpgid it is its own group leader (pgid == pid), so
|
||||
// targeting -pid reaches the child and every descendant still in the group.
|
||||
// Falls back to signalling just the child if the group send fails (e.g. the
|
||||
// group has already drained), so we never silently skip the signal.
|
||||
func signalProcessTree(cmd *exec.Cmd, sig syscall.Signal) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
if err := syscall.Kill(-cmd.Process.Pid, sig); err != nil {
|
||||
return cmd.Process.Signal(sig)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,14 +3,51 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
// setProcAttributes sets platform-specific process attributes. CREATE_NO_WINDOW
|
||||
// keeps the upstream from spawning its own console window.
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
|
||||
// terminateProcessTree requests a graceful shutdown of the whole process tree
|
||||
// rooted at cmd.Process. Windows has no SIGTERM or process-group signalling, so
|
||||
// we shell out to `taskkill /t`, which walks the child tree by PID — the
|
||||
// equivalent of signalling a Unix process group. Without /f, taskkill asks the
|
||||
// processes to close rather than force-killing them.
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
return taskkillProcessTree(cmd, false)
|
||||
}
|
||||
|
||||
// killProcessTree force-terminates the whole process tree rooted at cmd.Process
|
||||
// via `taskkill /f /t`, so any descendant that ignored or outlived the graceful
|
||||
// request is killed alongside the parent rather than leaked as an orphan.
|
||||
func killProcessTree(cmd *exec.Cmd) error {
|
||||
return taskkillProcessTree(cmd, true)
|
||||
}
|
||||
|
||||
// taskkillProcessTree runs taskkill against cmd.Process.Pid. The /t flag
|
||||
// terminates the process together with any child processes it started, which is
|
||||
// the Windows analogue of signalling a Unix process group via its negative PID.
|
||||
// When force is true the /f flag force-kills; otherwise taskkill requests a
|
||||
// graceful close.
|
||||
func taskkillProcessTree(cmd *exec.Cmd, force bool) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
args := make([]string, 0, 4)
|
||||
if force {
|
||||
args = append(args, "/f")
|
||||
}
|
||||
args = append(args, "/t", "/pid", fmt.Sprintf("%d", cmd.Process.Pid))
|
||||
kill := exec.Command("taskkill", args...)
|
||||
setProcAttributes(kill)
|
||||
return kill.Run()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
// SetupTreeCleanup is a no-op on non-Windows platforms, where upstream process
|
||||
// teardown is handled via process-group signalling (see runtime_unix.go).
|
||||
func SetupTreeCleanup() error { return nil }
|
||||
@@ -0,0 +1,50 @@
|
||||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// SetupTreeCleanup assigns the current process to a Windows Job Object
|
||||
// configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. Upstream processes
|
||||
// spawned afterwards are associated with the same job, so when llama-swap exits
|
||||
// for any reason — graceful shutdown, a forced second Ctrl+C, or a crash — the
|
||||
// OS terminates the whole job and reaps every child instead of leaving orphans
|
||||
// behind. It is the parent-side complement to the per-process teardown in
|
||||
// runtime_windows.go.
|
||||
//
|
||||
// The job handle is intentionally leaked for the lifetime of the process: the
|
||||
// kill-on-close behaviour fires when the last handle is released, which the OS
|
||||
// does when the process exits.
|
||||
func SetupTreeCleanup() error {
|
||||
job, err := windows.CreateJobObject(nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateJobObject: %w", err)
|
||||
}
|
||||
|
||||
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
|
||||
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
|
||||
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||
},
|
||||
}
|
||||
if _, err := windows.SetInformationJobObject(
|
||||
job,
|
||||
windows.JobObjectExtendedLimitInformation,
|
||||
uintptr(unsafe.Pointer(&info)),
|
||||
uint32(unsafe.Sizeof(info)),
|
||||
); err != nil {
|
||||
windows.CloseHandle(job)
|
||||
return fmt.Errorf("SetInformationJobObject: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.AssignProcessToJobObject(job, windows.CurrentProcess()); err != nil {
|
||||
windows.CloseHandle(job)
|
||||
return fmt.Errorf("AssignProcessToJobObject: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+155
-425
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type shutdownReq struct {
|
||||
@@ -24,66 +26,38 @@ type unloadReq struct {
|
||||
respond chan struct{}
|
||||
}
|
||||
|
||||
type handlerReq struct {
|
||||
model string
|
||||
ctx context.Context
|
||||
respond chan handlerResp
|
||||
positionCh chan int
|
||||
}
|
||||
|
||||
type handlerResp struct {
|
||||
handleFunc http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type swapDone struct {
|
||||
modelID string
|
||||
err error
|
||||
}
|
||||
|
||||
type serveDoneEvent struct {
|
||||
modelID string
|
||||
}
|
||||
|
||||
type activeSwap struct {
|
||||
modelID string
|
||||
evict []string
|
||||
waiters []handlerReq
|
||||
}
|
||||
|
||||
// swapPlanner is the only piece of behaviour that differs between concrete
|
||||
// routers. baseRouter never inspects its internals.
|
||||
type swapPlanner interface {
|
||||
// EvictionFor returns running model IDs that must be stopped before
|
||||
// target can serve. alsoRunning lists models the baseRouter has already
|
||||
// committed to loading (in-flight swaps) which the planner cannot see
|
||||
// via process.State() yet. Pure decision; must not log.
|
||||
EvictionFor(target string, alsoRunning []string) []string
|
||||
|
||||
// OnSwapStart runs once at the start of every swap. Planners may log
|
||||
// their decision here at whatever verbosity they choose.
|
||||
OnSwapStart(target string)
|
||||
}
|
||||
|
||||
// baseRouter owns the channels, run-loop, and orchestration code shared by
|
||||
// every concrete router. Concrete routers embed *baseRouter and supply a
|
||||
// swapPlanner that captures how their eviction set is decided.
|
||||
// baseRouter owns the channels, run-loop, and process machinery shared by every
|
||||
// concrete router. Concrete routers embed *baseRouter and supply a
|
||||
// scheduler.Swapper describing how eviction sets are decided. baseRouter
|
||||
// implements scheduler.Effects so the scheduler can call back for side-effects.
|
||||
type baseRouter struct {
|
||||
name string
|
||||
config config.Config
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
planner swapPlanner
|
||||
schedule scheduler.Scheduler
|
||||
|
||||
// shutdownCtx governs the request machinery: cancelling it tells grant()
|
||||
// and ServeHTTP to stop granting and reject callers. It is deliberately
|
||||
// separate from procCtx — see procCtx below.
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
|
||||
handlerCh chan handlerReq
|
||||
// procCtx is the parent context for every managed process and governs
|
||||
// process lifetime only. handleShutdown stops processes gracefully via
|
||||
// Stop() and cancels procCtx afterwards, so teardown is never a context
|
||||
// cancel racing the graceful path (which collapsed the grace to 100ms and
|
||||
// let the caller return before children were reaped — see process run loop).
|
||||
procCtx context.Context
|
||||
procCancel context.CancelFunc
|
||||
|
||||
handlerCh chan scheduler.HandlerReq
|
||||
cancelCh chan scheduler.HandlerReq
|
||||
shutdownCh chan shutdownReq
|
||||
unloadCh chan unloadReq
|
||||
swapDoneCh chan swapDone
|
||||
serveDoneCh chan serveDoneEvent
|
||||
swapDoneCh chan scheduler.SwapDone
|
||||
serveDoneCh chan scheduler.ServeDoneEvent
|
||||
|
||||
runDone chan struct{}
|
||||
|
||||
@@ -95,23 +69,38 @@ type baseRouter struct {
|
||||
testProcessed chan struct{}
|
||||
}
|
||||
|
||||
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
|
||||
func newBaseRouter(
|
||||
name string,
|
||||
conf config.Config,
|
||||
processes map[string]process.Process,
|
||||
logger *logmon.Monitor,
|
||||
planner scheduler.Swapper,
|
||||
) (*baseRouter, error) {
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
return &baseRouter{
|
||||
procCtx, procCancel := context.WithCancel(context.Background())
|
||||
b := &baseRouter{
|
||||
name: name,
|
||||
config: conf,
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
planner: planner,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
handlerCh: make(chan handlerReq),
|
||||
procCtx: procCtx,
|
||||
procCancel: procCancel,
|
||||
handlerCh: make(chan scheduler.HandlerReq),
|
||||
cancelCh: make(chan scheduler.HandlerReq),
|
||||
shutdownCh: make(chan shutdownReq),
|
||||
unloadCh: make(chan unloadReq),
|
||||
swapDoneCh: make(chan swapDone),
|
||||
serveDoneCh: make(chan serveDoneEvent),
|
||||
swapDoneCh: make(chan scheduler.SwapDone),
|
||||
serveDoneCh: make(chan scheduler.ServeDoneEvent),
|
||||
runDone: make(chan struct{}),
|
||||
}
|
||||
sched, err := scheduler.New(conf, name, logger, planner, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.schedule = sched
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *baseRouter) notifyProcessed() {
|
||||
@@ -123,30 +112,31 @@ func (b *baseRouter) notifyProcessed() {
|
||||
func (b *baseRouter) run() {
|
||||
defer close(b.runDone)
|
||||
|
||||
active := make(map[string]*activeSwap)
|
||||
inFlight := make(map[string]int)
|
||||
var queued []handlerReq
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh:
|
||||
b.handleShutdown(req, active, queued)
|
||||
b.handleShutdown(req)
|
||||
return
|
||||
|
||||
case req := <-b.handlerCh:
|
||||
b.handleRequest(req, active, inFlight, &queued)
|
||||
b.schedule.OnRequest(req)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.cancelCh:
|
||||
b.schedule.OnCancel(req)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.unloadCh:
|
||||
b.handleUnload(req, active, inFlight, &queued)
|
||||
b.schedule.OnUnload(req.targets, req.timeout)
|
||||
close(req.respond)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.swapDoneCh:
|
||||
b.handleSwapDone(ev, active, inFlight, &queued)
|
||||
b.schedule.OnSwapDone(ev)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.serveDoneCh:
|
||||
b.handleServeDone(ev, active, inFlight, &queued)
|
||||
b.schedule.OnServeDone(ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,37 +153,68 @@ func (b *baseRouter) run() {
|
||||
// down, the send never lands, one of the other select cases fires, and we
|
||||
// report back that the grant did NOT happen.
|
||||
//
|
||||
// That distinction matters for in-flight bookkeeping — see grantHandler.
|
||||
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
|
||||
// That distinction matters for in-flight bookkeeping — see GrantServe.
|
||||
func (b *baseRouter) grant(req scheduler.HandlerReq, resp scheduler.HandlerResp) bool {
|
||||
select {
|
||||
case req.respond <- resp:
|
||||
case req.Respond <- resp:
|
||||
return true
|
||||
case <-req.ctx.Done():
|
||||
case <-req.Ctx.Done():
|
||||
return false
|
||||
case <-b.shutdownCtx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// grantHandler is the "this caller can now use process p" path. It does
|
||||
// two things that must stay locked together:
|
||||
//
|
||||
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
|
||||
// HTTP request finishes, the run loop hears about it.
|
||||
// 2. Bump inFlight[modelID] so the router knows this process is busy and
|
||||
// refuses to evict it until the count comes back down.
|
||||
//
|
||||
// The increment is gated on grant() returning true. If grant() returns
|
||||
// false, the caller already walked away and trackedServe will never run —
|
||||
// which means no matching decrement will ever arrive on serveDoneCh.
|
||||
// Incrementing in that case would strand the counter at >0 forever and
|
||||
// the router would never again be willing to swap this model out.
|
||||
//
|
||||
// In short: increment if and only if we know a decrement is coming.
|
||||
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
|
||||
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
|
||||
inFlight[modelID]++
|
||||
// ModelState implements scheduler.Effects.
|
||||
func (b *baseRouter) ModelState(modelID string) (process.ProcessState, bool) {
|
||||
p, ok := b.processes[modelID]
|
||||
if !ok {
|
||||
var zero process.ProcessState
|
||||
return zero, false
|
||||
}
|
||||
return p.State(), true
|
||||
}
|
||||
|
||||
// StartSwap implements scheduler.Effects, launching the swap goroutine.
|
||||
func (b *baseRouter) StartSwap(modelID string, evict []string) {
|
||||
go b.doSwap(modelID, evict)
|
||||
}
|
||||
|
||||
// GrantError implements scheduler.Effects.
|
||||
func (b *baseRouter) GrantError(req scheduler.HandlerReq, err error) {
|
||||
b.grant(req, scheduler.HandlerResp{Err: err})
|
||||
}
|
||||
|
||||
// GrantServe implements scheduler.Effects. It hands the caller a wrapped
|
||||
// p.ServeHTTP (via trackedServe) so the run loop hears about the request
|
||||
// finishing, and reports whether the caller received it. The scheduler bumps
|
||||
// its in-flight count only on a true return: if grant() returns false the
|
||||
// caller already walked away and trackedServe will never run, so no matching
|
||||
// decrement will ever arrive — incrementing would strand the counter at >0 and
|
||||
// the router would never again be willing to evict this model.
|
||||
func (b *baseRouter) GrantServe(req scheduler.HandlerReq, modelID string) bool {
|
||||
p := b.processes[modelID]
|
||||
return b.grant(req, scheduler.HandlerResp{HandleFunc: b.trackedServe(modelID, p)})
|
||||
}
|
||||
|
||||
// StopProcesses implements scheduler.Effects, stopping the named processes in
|
||||
// parallel and blocking until all have stopped.
|
||||
func (b *baseRouter) StopProcesses(timeout time.Duration, ids []string) {
|
||||
var wg sync.WaitGroup
|
||||
for _, id := range ids {
|
||||
p, ok := b.processes[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(timeout); err != nil {
|
||||
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(id, p)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
||||
@@ -210,7 +231,7 @@ func (b *baseRouter) trackedServe(modelID string, p process.Process) http.Handle
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
select {
|
||||
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
|
||||
case b.serveDoneCh <- scheduler.ServeDoneEvent{ModelID: modelID}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
@@ -218,240 +239,6 @@ func (b *baseRouter) trackedServe(modelID string, p process.Process) http.Handle
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest decides what to do with one incoming ServeHTTP request. It is
|
||||
// called from run() and never blocks indefinitely: any work that has to wait
|
||||
// (starting a process, stopping siblings, waiting for ready) is deferred to
|
||||
// a swap goroutine and reported back via swapDoneCh.
|
||||
//
|
||||
// The decision tree, in order:
|
||||
//
|
||||
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
|
||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||
// one swap serves all callers that asked for the same model.
|
||||
// 3. Fast path — the target process is already ready, the planner sees
|
||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
|
||||
// 4. Would collide with an in-flight swap (we'd stop their target, or
|
||||
// they're stopping us) — park in the queue for handleSwapDone to drain.
|
||||
// 5. Would evict a process that is still handling requests — park in the
|
||||
// queue. handleServeDone will retry when the busy process drains.
|
||||
// 6. Otherwise — start a new swap. This may run in parallel with other
|
||||
// active swaps when their evict sets don't intersect.
|
||||
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
// (1) Unknown model.
|
||||
p, ok := b.processes[req.model]
|
||||
if !ok {
|
||||
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
|
||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||
return
|
||||
}
|
||||
|
||||
// (2) Join an in-flight swap for the same model.
|
||||
if s, ok := active[req.model]; ok {
|
||||
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
|
||||
s.waiters = append(s.waiters, req)
|
||||
return
|
||||
}
|
||||
|
||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||
|
||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
|
||||
b.grantHandler(req, req.model, p, inFlight)
|
||||
return
|
||||
}
|
||||
|
||||
// (4) Collision with an in-flight swap — queue.
|
||||
if collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
|
||||
*queued = append(*queued, req)
|
||||
b.broadcastQueuePositions(*queued)
|
||||
return
|
||||
}
|
||||
|
||||
// (5) Would evict a busy process — queue until it drains.
|
||||
if conflictsWithInFlight(evict, inFlight) {
|
||||
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
|
||||
*queued = append(*queued, req)
|
||||
b.broadcastQueuePositions(*queued)
|
||||
return
|
||||
}
|
||||
|
||||
// (6) Start a new (possibly parallel) swap.
|
||||
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
|
||||
s := b.startSwap(req, evict)
|
||||
active[s.modelID] = s
|
||||
}
|
||||
|
||||
// handleSwapDone is called from run() when a swap goroutine reports that it
|
||||
// has finished. It fans out the result to every waiter that joined this swap,
|
||||
// removes the swap from the active map, and then walks the queue once,
|
||||
// promoting any items that no longer collide with the remaining active set.
|
||||
// FIFO order is preserved: items still blocked stay in place.
|
||||
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
s, ok := active[ev.modelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(active, ev.modelID)
|
||||
|
||||
for _, w := range s.waiters {
|
||||
if ev.err != nil {
|
||||
b.grant(w, handlerResp{err: ev.err})
|
||||
} else {
|
||||
p := b.processes[ev.modelID]
|
||||
b.grantHandler(w, ev.modelID, p, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
}
|
||||
|
||||
// handleServeDone is called from run() each time a tracked ServeHTTP
|
||||
// finishes. It decrements the per-model in-flight count and, when that
|
||||
// drops to zero, retries the queue: requests whose swap was deferred
|
||||
// because they would have evicted this (now-idle) process can now proceed.
|
||||
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
inFlight[ev.modelID]--
|
||||
if inFlight[ev.modelID] <= 0 {
|
||||
delete(inFlight, ev.modelID)
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
}
|
||||
}
|
||||
|
||||
// drainQueue walks the queued requests in order, re-running the handleRequest
|
||||
// decision tree against the (now smaller) active set. Items that can now start
|
||||
// or join become satisfied; items still blocked remain queued in original
|
||||
// order so they get another chance on the next swap completion.
|
||||
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
if len(*queued) == 0 {
|
||||
return
|
||||
}
|
||||
pending := *queued
|
||||
var remaining []handlerReq
|
||||
for _, req := range pending {
|
||||
p, ok := b.processes[req.model]
|
||||
if !ok {
|
||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||
continue
|
||||
}
|
||||
if s, ok := active[req.model]; ok {
|
||||
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
|
||||
s.waiters = append(s.waiters, req)
|
||||
continue
|
||||
}
|
||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
|
||||
b.grantHandler(req, req.model, p, inFlight)
|
||||
continue
|
||||
}
|
||||
if collidesWith(req.model, evict, active) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
if conflictsWithInFlight(evict, inFlight) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
|
||||
s := b.startSwap(req, evict)
|
||||
active[s.modelID] = s
|
||||
}
|
||||
*queued = remaining
|
||||
b.broadcastQueuePositions(*queued)
|
||||
}
|
||||
|
||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||
// drained first so the consumer always sees the latest position.
|
||||
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
|
||||
for i, req := range queued {
|
||||
pos := i + 1
|
||||
select {
|
||||
case req.positionCh <- pos:
|
||||
default:
|
||||
select {
|
||||
case <-req.positionCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case req.positionCh <- pos:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
|
||||
swap := &activeSwap{
|
||||
modelID: initial.model,
|
||||
evict: evict,
|
||||
waiters: []handlerReq{initial},
|
||||
}
|
||||
b.planner.OnSwapStart(initial.model)
|
||||
go b.doSwap(initial.model, evict)
|
||||
return swap
|
||||
}
|
||||
|
||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||
// baseRouter passes this to the planner so eviction decisions account for
|
||||
// models that have been committed to but have not yet transitioned to
|
||||
// StateStarting in their process state machine.
|
||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||
if len(active) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(active))
|
||||
for id := range active {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// collidesWith reports whether a new swap with this target and evict set can
|
||||
// safely run alongside the currently active swaps. Same-target callers should
|
||||
// JOIN (handled before this) — they do not collide with themselves.
|
||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||
for id, s := range active {
|
||||
if id == target {
|
||||
continue
|
||||
}
|
||||
if containsString(evict, id) {
|
||||
return true
|
||||
}
|
||||
if containsString(s.evict, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
||||
// requests. Stopping a busy process would cancel its callers' connections,
|
||||
// so the router defers the swap until those callers finish.
|
||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
||||
for _, m := range evict {
|
||||
if inFlight[m] > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(xs []string, s string) bool {
|
||||
for _, x := range xs {
|
||||
if x == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||
timeout := b.healthCheckTimeout()
|
||||
|
||||
@@ -479,29 +266,24 @@ func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||
err := target.WaitReady(b.shutdownCtx)
|
||||
|
||||
select {
|
||||
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
|
||||
case b.swapDoneCh <- scheduler.SwapDone{ModelID: modelID, Err: err}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
|
||||
func (b *baseRouter) handleShutdown(req shutdownReq) {
|
||||
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
||||
|
||||
// Cancel shutdownCtx first so any waiter that is currently parked on
|
||||
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
||||
// The grant calls below then either land (waiter happened to receive
|
||||
// The OnShutdown grants below then either land (waiter happened to receive
|
||||
// before noticing shutdown) or fall through immediately via grant's
|
||||
// shutdownCtx case — either way the waiter sees a non-OK response.
|
||||
// This does NOT touch processes: their lifetime is procCtx, cancelled
|
||||
// only after the graceful Stop() calls below have reaped them.
|
||||
b.shutdownFn()
|
||||
|
||||
for _, s := range active {
|
||||
for _, w := range s.waiters {
|
||||
b.grant(w, handlerResp{err: shutdownErr})
|
||||
}
|
||||
}
|
||||
for _, w := range queued {
|
||||
b.grant(w, handlerResp{err: shutdownErr})
|
||||
}
|
||||
b.schedule.OnShutdown(shutdownErr)
|
||||
|
||||
stopTimeout := req.timeout
|
||||
if stopTimeout <= 0 {
|
||||
@@ -535,6 +317,11 @@ func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSw
|
||||
<-done
|
||||
}
|
||||
|
||||
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
|
||||
// the process run-loop goroutines exit; they are already StateStopped, so
|
||||
// this is a clean no-op kill rather than a forced teardown.
|
||||
b.procCancel()
|
||||
|
||||
req.respond <- nil
|
||||
}
|
||||
|
||||
@@ -607,75 +394,6 @@ func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
|
||||
<-req.respond
|
||||
}
|
||||
|
||||
// handleUnload runs on the run loop in response to an Unload call. It
|
||||
// reconciles router-owned state with the impending Stop, then performs
|
||||
// the Stop synchronously so callers of Unload remain blocked until each
|
||||
// targeted process has actually exited.
|
||||
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(req.targets))
|
||||
for _, id := range req.targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
// Release waiters of any in-flight swap whose target is being
|
||||
// unloaded. The swap goroutine itself is left to finish on its own;
|
||||
// when its swapDone arrives, handleSwapDone will find no entry in
|
||||
// active and silently drop it.
|
||||
for id := range targetSet {
|
||||
s, ok := active[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, w := range s.waiters {
|
||||
b.grant(w, handlerResp{err: unloadErr})
|
||||
}
|
||||
delete(active, id)
|
||||
}
|
||||
|
||||
// Drop queued requests addressed to unloaded models. Requests for
|
||||
// other models stay queued and may benefit from drainQueue at the end.
|
||||
if len(*queued) > 0 {
|
||||
kept := (*queued)[:0]
|
||||
for _, w := range *queued {
|
||||
if targetSet[w.model] {
|
||||
b.grant(w, handlerResp{err: unloadErr})
|
||||
continue
|
||||
}
|
||||
kept = append(kept, w)
|
||||
}
|
||||
*queued = kept
|
||||
}
|
||||
|
||||
// Stop the targeted processes. Done synchronously so Unload's caller
|
||||
// can rely on "after Unload returns, the process is stopped". inFlight
|
||||
// is intentionally NOT cleared here: each dying handler will fire its
|
||||
// trackedServe defer and reach handleServeDone in the normal way once
|
||||
// the run loop is free again.
|
||||
var wg sync.WaitGroup
|
||||
for id := range targetSet {
|
||||
p, ok := b.processes[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(req.timeout); err != nil {
|
||||
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(id, p)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Removing entries from active above may have unblocked queued
|
||||
// requests that previously collided with the now-cancelled swaps.
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
|
||||
close(req.respond)
|
||||
}
|
||||
|
||||
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||
if !b.shuttingDown.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("%s shutdown already in progress", b.name)
|
||||
@@ -691,24 +409,24 @@ func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||
|
||||
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if b.shuttingDown.Load() {
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
data, err := FetchContext(req, b.config)
|
||||
data, err := shared.FetchContext(req, b.config)
|
||||
if err != nil {
|
||||
SendError(w, req, err)
|
||||
shared.SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
hr := handlerReq{
|
||||
model: data.ModelID,
|
||||
ctx: req.Context(),
|
||||
// Unbuffered: a successful send on respond proves the waiter is
|
||||
hr := scheduler.HandlerReq{
|
||||
Model: data.ModelID,
|
||||
Ctx: req.Context(),
|
||||
// Unbuffered: a successful send on Respond proves the waiter is
|
||||
// alive and consuming. grant() relies on this to avoid handing a
|
||||
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
||||
respond: make(chan handlerResp),
|
||||
positionCh: make(chan int, 1),
|
||||
Respond: make(chan scheduler.HandlerResp),
|
||||
PositionCh: make(chan int, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -716,7 +434,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
case <-req.Context().Done():
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,7 +454,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case pos := <-hr.positionCh:
|
||||
case pos := <-hr.PositionCh:
|
||||
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
||||
case <-swapCtx.Done():
|
||||
return
|
||||
@@ -745,31 +463,43 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
}()
|
||||
}
|
||||
|
||||
var resp handlerResp
|
||||
select {
|
||||
case resp = <-hr.respond:
|
||||
// finishLoading stops the loading stream and fences its goroutine off from
|
||||
// the ResponseWriter before the real handler (or ServeHTTP's return)
|
||||
// reclaims it. release() must run even when waitForCompletion times out:
|
||||
// otherwise a still-streaming goroutine flushes a finalized response and
|
||||
// panics on the recycled *bufio.Writer.
|
||||
finishLoading := func() {
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
lw.release()
|
||||
}
|
||||
}
|
||||
|
||||
var resp scheduler.HandlerResp
|
||||
select {
|
||||
case resp = <-hr.Respond:
|
||||
finishLoading()
|
||||
case <-req.Context().Done():
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
finishLoading()
|
||||
// Notify the scheduler so it can prune this request from its queue
|
||||
// and swap waiters. Without this, a queued request whose client left
|
||||
// would sit in the scheduler until drainQueue eventually starts a
|
||||
// wasted model load for it.
|
||||
select {
|
||||
case b.cancelCh <- hr:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
}
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
finishLoading()
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
SendError(w, req, resp.err)
|
||||
if resp.Err != nil {
|
||||
shared.SendError(w, req, resp.Err)
|
||||
return
|
||||
}
|
||||
resp.handleFunc(w, req)
|
||||
resp.HandleFunc(w, req)
|
||||
}
|
||||
|
||||
+15
-614
@@ -5,35 +5,34 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
// stubPlanner is a swapPlanner that returns a fixed eviction list per target
|
||||
// and never logs. It lets the base-router tests cover shared run-loop
|
||||
// behaviour without dragging in either real router's eviction rules.
|
||||
type stubPlanner struct {
|
||||
evict map[string][]string
|
||||
}
|
||||
// These tests cover baseRouter's own machinery — the run loop, process
|
||||
// lifecycle (doSwap), grant/ServeHTTP plumbing, Unload, and Shutdown. The
|
||||
// scheduling decision logic (queueing, collation, eviction collisions) lives in
|
||||
// the scheduler package and is tested directly there; see fifo_test.go.
|
||||
|
||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||
if s.evict == nil {
|
||||
return nil
|
||||
}
|
||||
return s.evict[target]
|
||||
}
|
||||
// stubPlanner evicts nothing. baseRouter tests drive the run loop through the
|
||||
// default FIFO scheduler without exercising any particular eviction policy.
|
||||
type stubPlanner struct{}
|
||||
|
||||
func (s *stubPlanner) OnSwapStart(string) {}
|
||||
func (s *stubPlanner) EvictionFor(string, []string) []string { return nil }
|
||||
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||
|
||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
|
||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
|
||||
t.Helper()
|
||||
conf := config.Config{HealthCheckTimeout: 5}
|
||||
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||
b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner)
|
||||
if err != nil {
|
||||
t.Fatalf("newBaseRouter: %v", err)
|
||||
}
|
||||
b.testProcessed = make(chan struct{}, 64)
|
||||
go b.run()
|
||||
t.Cleanup(func() {
|
||||
@@ -157,114 +156,6 @@ func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
|
||||
// rejoins router state: a request whose swap to the unloaded model is
|
||||
// still in progress receives an error, instead of being abandoned
|
||||
// against a process that's about to vanish.
|
||||
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false: the swap parks on WaitReady so we can interrupt
|
||||
// it with Unload before it completes.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
close(done)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
|
||||
<-a.runStarted
|
||||
|
||||
b.Unload(time.Second, "a")
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("ServeHTTP did not return after Unload")
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if a.State() != process.StateStopped {
|
||||
t.Errorf("a state=%q want stopped", a.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
|
||||
// for an unloaded model receive an error rather than sitting forever in
|
||||
// the queue against state the router no longer maintains.
|
||||
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
// Loading B evicts A — so a request for B while A is loading queues.
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||
|
||||
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// r2 for B collides with A's in-flight swap and queues.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Unload B — r2 (queued, targeting B) must be released with an error.
|
||||
b.Unload(time.Second, "b")
|
||||
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("queued B request did not return after Unload(b)")
|
||||
}
|
||||
if w2.Code == http.StatusOK {
|
||||
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
|
||||
}
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
|
||||
}
|
||||
|
||||
// Release r1 so the test cleans up cleanly.
|
||||
a.markReady()
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r1 did not complete after a.markReady")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_FastPath(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("serveCalls=%d want 1", got)
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 0 {
|
||||
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.autoReady = true
|
||||
@@ -285,43 +176,6 @@ func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so the swap parks on WaitReady until we release it.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
const N = 5
|
||||
var wg sync.WaitGroup
|
||||
codes := make([]int, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
codes[i] = w.Code
|
||||
}(i)
|
||||
}
|
||||
|
||||
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
|
||||
<-a.runStarted // swap goroutine reached Run()
|
||||
a.markReady()
|
||||
wg.Wait()
|
||||
|
||||
for i, c := range codes {
|
||||
if c != http.StatusOK {
|
||||
t.Errorf("request %d: status=%d", i, c)
|
||||
}
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != N {
|
||||
t.Errorf("serveCalls=%d want %d", got, N)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so swap parks forever until we mark ready.
|
||||
@@ -364,459 +218,6 @@ func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pa := newFakeProcess("b")
|
||||
|
||||
// Loading b must stop a.
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
|
||||
|
||||
// First request starts a swap to A; A's autoReady=false so it parks.
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// Second request for B should queue while A's swap is in flight.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
if got := pa.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
|
||||
}
|
||||
|
||||
// Release A's swap. B's swap should then run.
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
|
||||
<-pa.runStarted
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("A request did not complete")
|
||||
}
|
||||
pa.markReady()
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("queued B request did not complete after A's swap")
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
|
||||
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
|
||||
// second request for each model rides the fast path — either joining the
|
||||
// active swap, or being pulled out of the queue when handleSwapDone promotes
|
||||
// the next model.
|
||||
func TestBaseRouter_QueueCollation(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// Each model evicts the other two so all swaps are mutually exclusive.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
"b": {"a", "c"},
|
||||
"c": {"a", "b"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
var (
|
||||
completedMu sync.Mutex
|
||||
completed []string
|
||||
)
|
||||
record := func(id string) {
|
||||
completedMu.Lock()
|
||||
defer completedMu.Unlock()
|
||||
completed = append(completed, id)
|
||||
}
|
||||
|
||||
ids := []string{"a", "b", "c", "a", "b", "c"}
|
||||
var wg sync.WaitGroup
|
||||
for _, id := range ids {
|
||||
id := id
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest(id))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
|
||||
return
|
||||
}
|
||||
record(id)
|
||||
}()
|
||||
// Wait for run() to absorb this request before launching the next,
|
||||
// so handlerCh receives them in launch order.
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
|
||||
// All 6 are now parked in run()'s waiters/queue. Release each swap in
|
||||
// sequence, waiting deterministically for each promotion to fire.
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
|
||||
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
|
||||
|
||||
<-pc.runStarted
|
||||
pc.markReady()
|
||||
wg.Wait()
|
||||
|
||||
if got := len(completed); got != 6 {
|
||||
t.Fatalf("completed=%v want 6", completed)
|
||||
}
|
||||
|
||||
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
|
||||
// but waiter goroutines may be scheduled in any order after their respond
|
||||
// channel fires, so completion order isn't deterministic. Per-model counts
|
||||
// (combined with the runCalls checks below) are sufficient to prove queue
|
||||
// collation collapsed each pair into a single swap.
|
||||
aDone, bDone, cDone := 0, 0, 0
|
||||
for _, id := range completed {
|
||||
switch id {
|
||||
case "a":
|
||||
aDone++
|
||||
case "b":
|
||||
bDone++
|
||||
case "c":
|
||||
cDone++
|
||||
}
|
||||
}
|
||||
if aDone != 2 || bDone != 2 || cDone != 2 {
|
||||
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
|
||||
}
|
||||
|
||||
// Single swap per model — the second request for each must have ridden
|
||||
// the fast path (joined active swap or joined a queued sibling), not
|
||||
// triggered an extra Run.
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("a.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := pb.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := pc.runCalls.Load(); got != 1 {
|
||||
t.Errorf("c.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
|
||||
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
|
||||
// before either process is marked ready.
|
||||
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
// Empty evict sets for both: they can load in parallel.
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Both swaps must have reached Run() before either is marked ready —
|
||||
// proves they ran in parallel rather than serializing.
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request A did not complete")
|
||||
}
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request B did not complete")
|
||||
}
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
|
||||
// unblocks every queued request that no longer collides — they all start in
|
||||
// parallel rather than one-per-completion.
|
||||
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// A's swap evicts both B and C, so B and C must queue. Once A finishes
|
||||
// B and C themselves have empty evict sets, so they can start together.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// B and C arrive while A is loading. evict_b and evict_c are empty,
|
||||
// but collidesWith returns true because they appear in A's evict set.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
w3 := httptest.NewRecorder()
|
||||
done3 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w3, newRequest("c"))
|
||||
close(done3)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started early: runCalls=%d", got)
|
||||
}
|
||||
if got := pc.runCalls.Load(); got != 0 {
|
||||
t.Errorf("c started early: runCalls=%d", got)
|
||||
}
|
||||
|
||||
// Release A. The swapDone handler should drain the queue and start
|
||||
// both B and C in parallel.
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
|
||||
<-pb.runStarted
|
||||
<-pc.runStarted
|
||||
|
||||
pb.markReady()
|
||||
pc.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2, done3} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
|
||||
// the shutdown error to every waiter on every active swap AND to every
|
||||
// queued request.
|
||||
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// a and b load in parallel (empty evicts). c collides with both.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"c": {"a", "b"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
const waitersPer = 2
|
||||
var wg sync.WaitGroup
|
||||
codes := make([]int, 0, 2*waitersPer+1)
|
||||
var codesMu sync.Mutex
|
||||
record := func(code int) {
|
||||
codesMu.Lock()
|
||||
codes = append(codes, code)
|
||||
codesMu.Unlock()
|
||||
}
|
||||
|
||||
launch := func(model string) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest(model))
|
||||
record(w.Code)
|
||||
}()
|
||||
}
|
||||
|
||||
// Active swaps for a and b, each with 2 waiters.
|
||||
for i := 0; i < waitersPer; i++ {
|
||||
launch("a")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
for i := 0; i < waitersPer; i++ {
|
||||
launch("b")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
// c collides with both → queues.
|
||||
launch("c")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
if err := b.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
codesMu.Lock()
|
||||
defer codesMu.Unlock()
|
||||
if len(codes) != 2*waitersPer+1 {
|
||||
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
|
||||
}
|
||||
for i, c := range codes {
|
||||
if c == http.StatusOK {
|
||||
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
|
||||
// is not stopped to satisfy another model's swap while it is still handling
|
||||
// a request.
|
||||
//
|
||||
// Sequence:
|
||||
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
|
||||
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
|
||||
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
|
||||
// intent collides with A.
|
||||
// 4. r1 released — A finishes r1, then r3 is served by A.
|
||||
// 5. B's swap then proceeds; r2 is served by B.
|
||||
//
|
||||
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
|
||||
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
|
||||
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady left false: we markReady manually after observing runStarted,
|
||||
// so autoReady's setState(Ready) cannot race with a later Stop and leave
|
||||
// A in Ready, masking the bug.
|
||||
a.serveBlock = make(chan struct{})
|
||||
pb := newFakeProcess("b")
|
||||
// Same reasoning for B: park its swap on WaitReady until we choose.
|
||||
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||
|
||||
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A
|
||||
<-a.serveStarted
|
||||
|
||||
// r2 — would evict A. A must not be stopped while r1 is in flight.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// r3 — another request for A, arrives behind r2 and queues because
|
||||
// B's swap intent (which evicts A) is recorded as active.
|
||||
w3 := httptest.NewRecorder()
|
||||
done3 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w3, newRequest("a"))
|
||||
close(done3)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
|
||||
// The router must hold off B's swap until A has drained.
|
||||
close(a.serveBlock)
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r1 did not complete after serveBlock release")
|
||||
}
|
||||
|
||||
// Wait for B.Run before marking it ready: markReady before Run would
|
||||
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
|
||||
// implementation B's swap only starts after A has drained; in the
|
||||
// current implementation it has already started — either way runStarted
|
||||
// fires.
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r2 did not complete after B marked ready")
|
||||
}
|
||||
select {
|
||||
case <-done3:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r3 did not complete")
|
||||
}
|
||||
|
||||
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
|
||||
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
|
||||
}
|
||||
if w1.Body.String() != "ok:a" {
|
||||
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
|
||||
}
|
||||
if w3.Body.String() != "ok:a" {
|
||||
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
|
||||
}
|
||||
if w2.Body.String() != "ok:b" {
|
||||
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
|
||||
}
|
||||
if a.stoppedWhileServing.Load() {
|
||||
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
# Router design
|
||||
|
||||
A developer tutorial for the `internal/router` package and its `scheduler`
|
||||
sub-package.
|
||||
|
||||
## Intro
|
||||
|
||||
A llama-swap router is the component that sits behind the proxy and answers one
|
||||
question for every incoming request: _can this model serve right now, and if
|
||||
not, what has to happen first?_ Answering it means juggling three concerns that
|
||||
used to live tangled together in one type:
|
||||
|
||||
1. **Process machinery** — owning the OS processes, starting and stopping them,
|
||||
running health checks, and shuttling HTTP requests onto the right upstream.
|
||||
2. **Scheduling strategy** — the queue, in-flight bookkeeping, and the decision
|
||||
tree that turns one request into "serve now", "join an existing swap",
|
||||
"queue", or "start a swap".
|
||||
3. **Eviction policy** — given a model we want to load, which currently-running
|
||||
models have to be stopped to make room?
|
||||
|
||||
The design pulls those three apart into separate, independently replaceable
|
||||
pieces:
|
||||
|
||||
| Concern | Type | Lives in |
|
||||
| ------------------- | ------------------------------ | ------------------------------- |
|
||||
| Process machinery | `baseRouter` | `internal/router/base.go` |
|
||||
| Scheduling strategy | `scheduler.Scheduler` (`FIFO`) | `internal/router/scheduler/` |
|
||||
| Eviction policy | `scheduler.Swapper` | `groupSwapper`, `matrixSwapper` |
|
||||
|
||||
`baseRouter` keeps the channels, run loop, process lifecycle, and shutdown
|
||||
teardown, and exposes the side-effects a scheduler needs through the
|
||||
`scheduler.Effects` interface. The scheduler owns the queue and decision tree
|
||||
but performs no side-effects directly — it calls back through `Effects`. The
|
||||
`Swapper` is a pure function from "target model + currently running" to "models
|
||||
to evict", and knows nothing about queues, channels, or processes.
|
||||
|
||||
Because the seams are interfaces, you can replace the scheduling strategy
|
||||
without touching process management, or write a new eviction policy without
|
||||
touching either. `FIFO` is the first and currently only `Scheduler`;
|
||||
`groupSwapper` and `matrixSwapper` are the two `Swapper`s.
|
||||
|
||||
## Key concepts
|
||||
|
||||
### One run loop, no locks
|
||||
|
||||
`baseRouter.run()` is a single goroutine selecting over a handful of channels:
|
||||
|
||||
```go
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh: b.handleShutdown(req); return
|
||||
case req := <-b.handlerCh: b.schedule.OnRequest(req)
|
||||
case req := <-b.unloadCh: b.schedule.OnUnload(req.targets, req.timeout); close(req.respond)
|
||||
case ev := <-b.swapDoneCh: b.schedule.OnSwapDone(ev)
|
||||
case ev := <-b.serveDoneCh: b.schedule.OnServeDone(ev)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Every `Scheduler` method runs on this one goroutine. That is the single most
|
||||
important fact about the design: **the scheduler never needs a mutex for its own
|
||||
state**. All scheduler state is touched only from these callbacks, which are
|
||||
serialized by the run loop. If you write a new scheduler, you get the same
|
||||
guarantee for free — and you must not break it by spinning up goroutines that
|
||||
mutate scheduler state.
|
||||
|
||||
### Events flow in, side-effects flow out
|
||||
|
||||
The run loop turns external happenings into method calls on the scheduler:
|
||||
|
||||
- A new HTTP request becomes `OnRequest(HandlerReq)`.
|
||||
- A swap goroutine finishing becomes `OnSwapDone(SwapDone)`.
|
||||
- A tracked request handler returning becomes `OnServeDone(ServeDoneEvent)`.
|
||||
- An admin unload becomes `OnUnload(targets, timeout)`.
|
||||
- Shutdown becomes `OnShutdown(err)`.
|
||||
|
||||
The scheduler reacts by calling **back out** through `Effects`: inspect a
|
||||
process state, start a swap, grant a response to a caller, or stop processes. It
|
||||
never calls `process.Process` directly and never writes to a channel directly.
|
||||
This keeps the scheduler pure enough to unit-test against a fake `Effects` with
|
||||
no goroutines or real processes involved (see `scheduler/fifo_test.go`).
|
||||
|
||||
```
|
||||
HTTP request admin Unload / Shutdown
|
||||
│ │
|
||||
▼ ▼
|
||||
ServeHTTP ──HandlerReq──▶ baseRouter.run() ◀──unloadCh/shutdownCh
|
||||
│ (single goroutine)
|
||||
▼
|
||||
Scheduler.On*(...)
|
||||
│ calls back through
|
||||
▼
|
||||
Effects: ModelState / StartSwap /
|
||||
GrantServe / GrantError / StopProcesses
|
||||
│
|
||||
▼
|
||||
baseRouter side-effects: doSwap goroutine,
|
||||
grant() to caller, process.Stop()
|
||||
│
|
||||
swap completes ──SwapDone──▶ back into run loop
|
||||
```
|
||||
|
||||
### The swap goroutine
|
||||
|
||||
Scheduling decisions must be quick and non-blocking, but loading a model is
|
||||
slow. The two are reconciled by doing the slow part on a separate goroutine.
|
||||
|
||||
When the scheduler decides to start a swap, inside `OnRequest` it:
|
||||
|
||||
1. records "a swap for X is in flight" in its own state, then
|
||||
2. calls `Effects.StartSwap(modelID, evict)`.
|
||||
|
||||
`StartSwap` does **not** load the model itself — it just launches a detached
|
||||
goroutine (`doSwap`) and returns straight away. `doSwap` is what does the slow
|
||||
work: stop the evicted processes, start the target, wait for it to become ready.
|
||||
Because `StartSwap` returned immediately, `OnRequest` returns too, and the run
|
||||
loop is free to pick up the next event — another request, a serve-done, an
|
||||
unload — while `doSwap` runs in the background.
|
||||
|
||||
The swap's eventual result comes back as just another event: when `doSwap`
|
||||
finishes it posts a `SwapDone` onto `swapDoneCh`, which the run loop delivers as
|
||||
`OnSwapDone`. So a slow load never blocks the run loop; it brackets it with two
|
||||
quick events (`OnRequest` to start, `OnSwapDone` to finish) and everything in
|
||||
between is handled normally.
|
||||
|
||||
### In-flight tracking and `trackedServe`
|
||||
|
||||
When the scheduler grants a request, the handler it hands back is wrapped by
|
||||
`baseRouter.trackedServe`. The wrapper runs the real `ServeHTTP` and, on return,
|
||||
posts a `ServeDoneEvent` so the run loop can decrement the per-model in-flight
|
||||
count. This is why the scheduler can know whether a process is "busy": it counts
|
||||
grants out and serve-dones in. A swap that would evict a busy process is
|
||||
deferred until that process's in-flight count hits zero (`OnServeDone` then
|
||||
re-drains the queue).
|
||||
|
||||
The subtle contract here is `GrantServe`'s boolean return. The caller's
|
||||
`Respond` channel is unbuffered, so a successful send proves the HTTP goroutine
|
||||
is alive and took the handler. If the caller already disconnected, the send
|
||||
fails, `trackedServe` never runs, and **no** `ServeDoneEvent` will ever arrive —
|
||||
so the scheduler must only increment `inFlight` when `GrantServe` returns true.
|
||||
Incrementing on a false return would strand the counter above zero and the model
|
||||
could never be evicted again.
|
||||
|
||||
## The interfaces
|
||||
|
||||
All three live in `scheduler/scheduler.go`.
|
||||
|
||||
### `Scheduler`
|
||||
|
||||
```go
|
||||
type Scheduler interface {
|
||||
OnRequest(req HandlerReq)
|
||||
OnSwapDone(ev SwapDone)
|
||||
OnServeDone(ev ServeDoneEvent)
|
||||
OnUnload(targets []string, timeout time.Duration)
|
||||
OnShutdown(err error)
|
||||
}
|
||||
```
|
||||
|
||||
Owns the queue, in-flight tracking, and the decision tree. All methods run on
|
||||
the run-loop goroutine, so no internal locking is needed.
|
||||
|
||||
### `Swapper`
|
||||
|
||||
```go
|
||||
type Swapper interface {
|
||||
EvictionFor(target string, running []string) []string
|
||||
OnSwapStart(target string, running []string)
|
||||
}
|
||||
```
|
||||
|
||||
The eviction policy. `EvictionFor` is a **pure decision** — given the target and
|
||||
the complete `running` set, return the running model IDs that must stop. It must
|
||||
not log or mutate anything, and it does **not** inspect process state itself:
|
||||
the scheduler hands it `running` already assembled (every non-stopped process,
|
||||
unioned with the targets of in-flight swaps already committed but not yet
|
||||
visible in process state). That keeps the swapper a pure function of its inputs,
|
||||
with no reference to processes.
|
||||
|
||||
The reason it must not log is that it is a _speculative_ query — "what would we
|
||||
evict if we started this swap right now?" — called far more often than swaps
|
||||
actually happen. The scheduler calls it once per incoming request, and then
|
||||
**again for every still-queued request on every queue drain** (each `OnSwapDone`,
|
||||
`OnServeDone`, and `OnUnload`). Most of those calls end in "still queued",
|
||||
"collides", or "nothing to evict", not a real swap. Logging there would emit
|
||||
duplicate lines for a request that simply sits in the queue, and lines for
|
||||
decisions that never happen — the log would stop meaning "a swap occurred".
|
||||
|
||||
`OnSwapStart` is the one place a Swapper may log, because it is called exactly
|
||||
once, at the moment a swap is committed. One log line there equals one real swap,
|
||||
with the evict set that is genuinely being applied — which is why `matrixSwapper`
|
||||
re-solves and logs the full decision (set, DSL, cost) in `OnSwapStart` rather
|
||||
than in `EvictionFor`.
|
||||
|
||||
### `Effects`
|
||||
|
||||
```go
|
||||
type Effects interface {
|
||||
ModelState(modelID string) (process.ProcessState, bool)
|
||||
RunningModels() map[string]process.ProcessState
|
||||
StartSwap(modelID string, evict []string)
|
||||
GrantError(req HandlerReq, err error)
|
||||
GrantServe(req HandlerReq, modelID string) bool
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
```
|
||||
|
||||
Implemented by `baseRouter`. This is the scheduler's entire window onto the
|
||||
outside world; everything else about the router is hidden from it. See the
|
||||
deep-dive below.
|
||||
|
||||
### `Factory` — wiring it together
|
||||
|
||||
```go
|
||||
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler
|
||||
```
|
||||
|
||||
`baseRouter` doesn't know which scheduler or swapper it has — it is handed a
|
||||
`Factory` at construction and calls it once, passing itself as the `Effects`.
|
||||
The concrete router captures its `Swapper` in the closure. From `group.go`:
|
||||
|
||||
```go
|
||||
swapper := &groupSwapper{ /* ... */ }
|
||||
base := newBaseRouter("group", conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, swapper, eff)
|
||||
})
|
||||
```
|
||||
|
||||
This closure is the single point where the three pieces meet: it binds a
|
||||
specific `Swapper` (`swapper`) and a specific `Scheduler` (`FIFO`) to the
|
||||
`baseRouter`'s `Effects` (`eff`).
|
||||
|
||||
**The swapper is a separate type from the concrete router.** There are currently two router implementations router.Group and router.Matrix. Each of these has a custom swapper that implements scheduler.Swapper for custom eviction logic. This decoupling of responsibilities makes it easy to implement custom swapping strategies.
|
||||
|
||||
### The events
|
||||
|
||||
A single goroutine in `baseRouter.run()` owns and serializes all state changes in the router. By processing events one at a time it ensures correctness and eliminates complex mutex lock logic.
|
||||
|
||||
These are the events the router currently uses:
|
||||
|
||||
```go
|
||||
type HandlerReq struct { // one in-flight ServeHTTP awaiting a decision
|
||||
Model string
|
||||
Ctx context.Context
|
||||
Respond chan HandlerResp // UNBUFFERED — see GrantServe contract
|
||||
PositionCh chan int // queue-position updates for the loading UI
|
||||
}
|
||||
|
||||
type HandlerResp struct { // the decision handed back to the caller
|
||||
HandleFunc http.HandlerFunc // serve with this, or...
|
||||
Err error // ...fail with this
|
||||
}
|
||||
|
||||
type SwapDone struct{ ModelID string; Err error } // swap goroutine finished
|
||||
type ServeDoneEvent struct{ ModelID string } // tracked handler returned
|
||||
```
|
||||
|
||||
## Deep-dive: the `Effects` interface and why it exists
|
||||
|
||||
`Effects` is the inversion-of-control boundary that makes the split possible.
|
||||
The scheduler decides and `baseRouter` _acts_. Pulling the side-effects behind this
|
||||
interface buys three things:
|
||||
|
||||
1. **Purity and testability.** The scheduler performs no I/O, starts no
|
||||
goroutines of its own, and touches no real processes. Its tests drive the
|
||||
`On*` methods directly and assert on a `fakeEffects` that just records the
|
||||
calls — synchronous, deterministic, no sleeps. (`scheduler/fifo_test.go`.)
|
||||
2. **A single, auditable side-effect surface.** Every externally-visible thing a
|
||||
scheduler can do is one of six methods. You can reason about the whole
|
||||
contract by reading one interface.
|
||||
3. **Decoupling lifetime.** The scheduler never holds a `process.Process`,
|
||||
never sees a channel, and never learns how shutdown teardown works. It only
|
||||
knows model IDs and states.
|
||||
|
||||
Method by method, as implemented in `base.go`:
|
||||
|
||||
- **`ModelState(modelID) (state, ok)`** — read-only snapshot of a process's
|
||||
state, and whether this router handles the model at all. The scheduler uses it
|
||||
for the "unknown model" check and the "already ready" fast path. Safe to call
|
||||
any time because the process map is fixed at construction and `State()` is a
|
||||
snapshot.
|
||||
|
||||
- **`RunningModels()`** — the state of every process that isn't stopped or shut
|
||||
down. The scheduler unions its keys with its own in-flight swap targets to
|
||||
build the `running` set it hands the `Swapper`, so the swapper never has to
|
||||
touch process state itself.
|
||||
|
||||
- **`StartSwap(modelID, evict)`** — fire-and-forget. `baseRouter` launches the
|
||||
`doSwap` goroutine and returns immediately; the result comes back later as a
|
||||
`SwapDone`. The scheduler records the swap as active _before_ calling this so
|
||||
that requests arriving in the meantime can join it.
|
||||
|
||||
- **`GrantError(req, err)`** — hand a caller an error response. Used for unknown
|
||||
models, failed swaps, unloads, and shutdown.
|
||||
|
||||
- **`GrantServe(req, modelID) bool`** — hand a caller the tracked handler for a
|
||||
ready model, returning whether the caller was still there to receive it. The
|
||||
scheduler increments the in-flight count **only on a true return** (see the
|
||||
in-flight contract above). This is the one `Effects` method whose return value
|
||||
carries state-machine significance.
|
||||
|
||||
- **`StopProcesses(timeout, ids)`** — stop processes in parallel and **block**
|
||||
until all have stopped. Used by `OnUnload` so an admin `Unload` call can
|
||||
guarantee the process is dead by the time it returns. (Note `StartSwap` is
|
||||
async but `StopProcesses` is sync — the difference is deliberate and tied to
|
||||
the caller's expectations.)
|
||||
|
||||
A useful way to hold it in your head: `Effects` is the scheduler's syscall
|
||||
table. The scheduler is a pure state machine; `Effects` is how it touches the
|
||||
world, and `baseRouter` is the kernel that implements those syscalls with real
|
||||
goroutines, channels, and processes.
|
||||
|
||||
## How to implement a new `Swapper`
|
||||
|
||||
A `Swapper` is a pure decision function plus a logging hook — the easiest of the three pieces to replace.
|
||||
|
||||
1. **Write the swapper type** and give it whatever config it needs to make a
|
||||
decision. It does **not** need the process map — the scheduler supplies the
|
||||
running set as an argument. `groupSwapper` holds only its group config;
|
||||
`matrixSwapper` holds only its solver and logger:
|
||||
|
||||
```go
|
||||
type mySwapper struct {
|
||||
config config.Config
|
||||
}
|
||||
```
|
||||
|
||||
2. **Implement `EvictionFor(target, running)`** as a _pure_ decision:
|
||||
- `running` is the complete live set, already assembled for you: every
|
||||
non-stopped process unioned with the targets of in-flight swaps the
|
||||
scheduler has committed to. You don't filter process state or fold in
|
||||
in-flight targets yourself, that's the scheduler's job. Just decide against the slice you're handed.
|
||||
- Return the list of model IDs in `running` that must stop for `target` to
|
||||
run. Return `nil`/empty when nothing needs evicting.
|
||||
- Do **not** mutate state here.
|
||||
- Do **not** log here. It can be called multiple times per request. Since it is pure function have tests verify the expected behaviour.
|
||||
|
||||
3. **Implement `OnSwapStart(target, running)`** — called once when a swap
|
||||
actually begins, with the same `running` set `EvictionFor` saw. This is the
|
||||
right place to log: one call equals one real swap. `matrixSwapper` re-solves
|
||||
and logs the chosen set and cost here; `groupSwapper` logs nothing.
|
||||
|
||||
4. **Wire it in** by instantiating the swapper in your router's constructor and
|
||||
capturing it in the `Factory` closure passed to `newBaseRouter` — exactly as
|
||||
`NewGroup` and `NewMatrix` do. The router struct itself only ever embeds
|
||||
`*baseRouter`; the swapper reaches the scheduler solely through that closure.
|
||||
|
||||
Reference implementations: `groupSwapper` (static group config) in `group.go`
|
||||
and `matrixSwapper` (cost-based set solver) in `matrix.go`.
|
||||
|
||||
## How to implement a new `Scheduler`
|
||||
|
||||
Replacing the scheduler means taking over the queue and the entire decision tree. Read `scheduler/fifo.go` end to end first — it is the reference implementation and the rules below are easiest to understand in context.
|
||||
|
||||
The rules you must honour:
|
||||
|
||||
- **Single goroutine.** Every method runs on the `baseRouter.run()` goroutine. Keep your state in plain maps/slices and never read or write it from another goroutine. If you need slow work done, hand it to `Effects.StartSwap` and react to the resulting `SwapDone` — do not block a method waiting for it.
|
||||
|
||||
- **Never block the run loop.** `OnRequest`, `OnSwapDone`, and `OnServeDone` must make a decision and return. The one method allowed to block is `OnUnload`, and only because it must wait on the synchronous `StopProcesses` so the admin caller's guarantee holds.
|
||||
|
||||
- **Respect the `GrantServe` boolean.** Only count a request as in-flight when `GrantServe` returns true (see the in-flight contract above). A false return means the caller is gone; no `ServeDoneEvent` will ever arrive, so incrementing on false permanently strands the counter.
|
||||
|
||||
- **Account for in-flight swaps in your running set.** When you call `Swapper.EvictionFor`, the running set you pass must include not just live processes (`Effects.RunningModels`) but also the targets of swaps you've already started that aren't yet visible in process state — otherwise the swapper contradicts decisions already in motion.
|
||||
|
||||
What each method must do:
|
||||
|
||||
- **`OnRequest(req)`** — every request must resolve to exactly one of: granted, errored, joined (piggybacks an in-flight swap), queued, or swap-started. No request may be silently dropped.
|
||||
|
||||
- **`OnSwapDone(ev)`** — deliver the result to every waiter that joined this swap (grant on success, error on `ev.Err`), drop the swap from active tracking, then re-examine anything queued — a finished swap may have unblocked it.
|
||||
|
||||
- **`OnServeDone(ev)`** — decrement the model's in-flight count; when it hits zero, re-examine the queue. Do **not** clear in-flight counts by hand; the handlers post their own `ServeDoneEvent`s on return.
|
||||
|
||||
- **`OnUnload(targets, timeout)`** — error out any waiters or queued requests for the unloaded models, call `Effects.StopProcesses` (synchronously — the admin caller relies on the process being dead afterwards), then re-examine the queue.
|
||||
|
||||
- **`OnShutdown(err)`** — error out every waiter you still hold (active swap waiters and queued requests). Don't touch processes; teardown is `baseRouter`'s job.
|
||||
|
||||
Expose a constructor matching the `Factory` shape:
|
||||
|
||||
```go
|
||||
func NewMyScheduler(name string, logger *logmon.Monitor, swapper Swapper, eff Effects) *MyScheduler {
|
||||
// ...
|
||||
}
|
||||
|
||||
// in the concrete router:
|
||||
base := newBaseRouter(name, conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewMyScheduler(name, logger, swapper, eff)
|
||||
})
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
- **Schedulers** are tested as pure state machines in the `scheduler` package:
|
||||
drive the `On*` methods directly against a `fakeEffects` and assert on the
|
||||
recorded grants/starts/stops. No goroutines, no sleeps. See
|
||||
`scheduler/fifo_test.go` as the reference; follow the `TestSchedulerName_<scenario>`
|
||||
naming convention.
|
||||
- **`baseRouter` mechanism** (run loop, `grant`/`ServeHTTP`, `Unload`,
|
||||
`Shutdown`) is tested in `base_test.go`. The run loop exposes a
|
||||
`testProcessed` channel so tests can wait for an event to be fully processed
|
||||
instead of sleeping.
|
||||
- Run new tests with `go test -v -run TestMyScheduler_... ./internal/router/scheduler/`,
|
||||
then `make test-dev` for a quick `go test` + `staticcheck` pass over `proxy/`.
|
||||
+16
-20
@@ -14,7 +14,7 @@ type Group struct {
|
||||
|
||||
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Groups {
|
||||
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
if existing, dup := modelToGroup[mid]; dup {
|
||||
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
||||
@@ -23,25 +23,29 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
|
||||
}
|
||||
}
|
||||
|
||||
planner := &groupPlanner{
|
||||
swapper := &groupSwapper{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
}
|
||||
|
||||
processes := make(map[string]process.Process, len(modelToGroup))
|
||||
base := newBaseRouter("group", conf, processes, planner, proxylog)
|
||||
planner.processes = processes
|
||||
base, err := newBaseRouter("group", conf, processes, proxylog, swapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating base router: %w", err)
|
||||
}
|
||||
|
||||
for mid := range modelToGroup {
|
||||
modelCfg, _, ok := conf.FindConfig(mid)
|
||||
if !ok {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("no model config for %q", mid)
|
||||
}
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
|
||||
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
@@ -52,21 +56,20 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// groupPlanner decides evictions from static group configuration.
|
||||
// groupSwapper decides evictions from static group configuration.
|
||||
//
|
||||
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
||||
// members are stopped only when the target's group is exclusive; loading a
|
||||
// model from a non-exclusive group leaves running exclusive groups alone,
|
||||
// matching the gotcha in the original ProcessGroup behaviour.
|
||||
type groupPlanner struct {
|
||||
type groupSwapper struct {
|
||||
config config.Config
|
||||
modelToGroup map[string]string
|
||||
processes map[string]process.Process
|
||||
}
|
||||
|
||||
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||
func (p *groupSwapper) EvictionFor(target string, running []string) []string {
|
||||
tg := p.modelToGroup[target]
|
||||
tgCfg := p.config.Groups[tg]
|
||||
tgCfg := p.config.Routing.Router.Settings.Groups[tg]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
var result []string
|
||||
@@ -87,24 +90,17 @@ func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string
|
||||
// for backwards compatibility. The newer swap matrix approach does not
|
||||
// have this issue.
|
||||
case og != tg && tgCfg.Exclusive:
|
||||
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
|
||||
if ogCfg := p.config.Routing.Router.Settings.Groups[og]; !ogCfg.Persistent {
|
||||
seen[mID] = struct{}{}
|
||||
result = append(result, mID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for mID, proc := range p.processes {
|
||||
st := proc.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
consider(mID)
|
||||
}
|
||||
for _, mID := range alsoRunning {
|
||||
for _, mID := range running {
|
||||
consider(mID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *groupPlanner) OnSwapStart(target string) {}
|
||||
func (p *groupSwapper) OnSwapStart(target string, running []string) {}
|
||||
|
||||
@@ -17,17 +17,19 @@ import (
|
||||
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
||||
t.Helper()
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Groups {
|
||||
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
modelToGroup[mid] = gid
|
||||
}
|
||||
}
|
||||
planner := &groupPlanner{
|
||||
swapper := &groupSwapper{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
processes: processes,
|
||||
}
|
||||
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||
base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper)
|
||||
if err != nil {
|
||||
t.Fatalf("newBaseRouter: %v", err)
|
||||
}
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
g := &Group{baseRouter: base}
|
||||
go base.run()
|
||||
@@ -41,10 +43,10 @@ func newTestGroup(t *testing.T, conf config.Config, processes map[string]process
|
||||
|
||||
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
||||
conf := config.Config{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Members: []string{"a"}},
|
||||
},
|
||||
}),
|
||||
Models: map[string]config.ModelConfig{
|
||||
"a": {},
|
||||
},
|
||||
@@ -65,9 +67,9 @@ func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -97,9 +99,9 @@ func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -127,10 +129,10 @@ func TestGroup_CrossGroupExclusive(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -154,10 +156,10 @@ func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
@@ -202,16 +204,17 @@ func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||
|
||||
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
||||
// (Swap=true) serialise even when both arrive while neither has reached
|
||||
// StateStarting yet — the alsoRunning hint to the planner closes that race.
|
||||
// StateStarting yet — the in-flight swap target the scheduler folds into the
|
||||
// running set closes that race.
|
||||
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
@@ -224,8 +227,9 @@ func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
// Request B arrives before A transitions to StateStarting in the process
|
||||
// state machine. Without the alsoRunning hint, the planner would not see
|
||||
// A as running, and B would start in parallel, violating Swap=true.
|
||||
// state machine. Without folding the in-flight swap target into the running
|
||||
// set, the swapper would not see A as running, and B would start in
|
||||
// parallel, violating Swap=true.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
@@ -269,10 +273,10 @@ func TestGroup_PersistentNotEvicted(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
||||
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -306,10 +310,10 @@ func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
|
||||
@@ -12,10 +12,23 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// groupRouting builds a normalized RoutingConfig for the group router, mirroring
|
||||
// what config.LoadConfigFromReader produces. Tests use it to populate
|
||||
// config.Config.Routing without going through LoadConfig.
|
||||
func groupRouting(groups map[string]config.GroupConfig) config.RoutingConfig {
|
||||
return config.RoutingConfig{
|
||||
Router: config.RouterConfig{
|
||||
Use: "group",
|
||||
Settings: config.RouterSettings{Groups: groups},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// fakeProcess is an in-memory implementation of process.Process used to drive
|
||||
// the routers through their state machine without spawning real upstreams.
|
||||
type fakeProcess struct {
|
||||
|
||||
@@ -38,6 +38,13 @@ type loadingWriter struct {
|
||||
pendingMu sync.Mutex
|
||||
pendingUpdate string
|
||||
|
||||
// writeMu serializes writes to the underlying writer and guards released.
|
||||
// Once released is set, the streaming goroutine must not touch the writer
|
||||
// again — ServeHTTP has reclaimed it (to run the real handler or to return)
|
||||
// and writing/flushing a finalized response panics.
|
||||
writeMu sync.Mutex
|
||||
released bool
|
||||
|
||||
// closed by start when the goroutine finishes (after cleanup messages)
|
||||
done chan struct{}
|
||||
|
||||
@@ -217,12 +224,33 @@ func (s *loadingWriter) sendData(data string) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
// Once ServeHTTP has reclaimed the writer (release), writing/flushing it
|
||||
// races the real handler or panics on a finalized response. Stop here.
|
||||
if s.released {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData); err != nil {
|
||||
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
|
||||
return
|
||||
}
|
||||
s.Flush()
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// release fences the loadingWriter off from the underlying ResponseWriter.
|
||||
// After it returns, the streaming goroutine will not write to or flush the
|
||||
// writer again: any in-flight write completes under writeMu first, and later
|
||||
// writes short-circuit on released. The caller can then safely hand the writer
|
||||
// to the real handler or let ServeHTTP return without racing a finalized
|
||||
// response (a use-after-return Flush panics on the recycled *bufio.Writer).
|
||||
func (s *loadingWriter) release() {
|
||||
s.writeMu.Lock()
|
||||
s.released = true
|
||||
s.writeMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Header() http.Header {
|
||||
|
||||
@@ -226,69 +226,6 @@ func TestIsLoadingPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", "model=llama3&stream=true", true},
|
||||
{"streaming false", "model=llama3&stream=false", false},
|
||||
{"no stream param", "model=llama3", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||
{"no stream param", `{"model":"llama3"}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if !got.Streaming {
|
||||
t.Error("Streaming should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func countSSEMessages(s string) int {
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
count := 0
|
||||
|
||||
+18
-46
@@ -2,7 +2,6 @@ package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -14,26 +13,30 @@ type Matrix struct {
|
||||
}
|
||||
|
||||
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
||||
if conf.Matrix == nil {
|
||||
mtx := conf.Routing.Router.Settings.Matrix
|
||||
if mtx == nil {
|
||||
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
||||
}
|
||||
|
||||
planner := &matrixPlanner{
|
||||
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
|
||||
swapper := &matrixSwapper{
|
||||
solver: newMatrixSolver(mtx.ExpandedSets, mtx.ResolvedEvictCosts()),
|
||||
logger: proxylog,
|
||||
}
|
||||
|
||||
// Build a process for every model in the config. Any model can run alone
|
||||
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
||||
processes := make(map[string]process.Process, len(conf.Models))
|
||||
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
|
||||
planner.processes = processes
|
||||
base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating base router: %w", err)
|
||||
}
|
||||
|
||||
for mid, modelCfg := range conf.Models {
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
|
||||
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
@@ -44,20 +47,18 @@ func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matr
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// matrixPlanner decides evictions by asking the matrix solver against the
|
||||
// current running set.
|
||||
type matrixPlanner struct {
|
||||
solver *matrixSolver
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
// matrixSwapper decides evictions by asking the matrix solver against the
|
||||
// running set the scheduler hands it.
|
||||
type matrixSwapper struct {
|
||||
solver *matrixSolver
|
||||
logger *logmon.Monitor
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
|
||||
func (p *matrixSwapper) EvictionFor(target string, running []string) []string {
|
||||
return p.solver.Solve(target, running).Evict
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) OnSwapStart(target string) {
|
||||
running := p.runningModels()
|
||||
func (p *matrixSwapper) OnSwapStart(target string, running []string) {
|
||||
result := p.solver.Solve(target, running)
|
||||
switch {
|
||||
case len(result.Evict) > 0:
|
||||
@@ -69,32 +70,3 @@ func (p *matrixPlanner) OnSwapStart(target string) {
|
||||
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) runningModels() []string {
|
||||
return p.runningSet(nil)
|
||||
}
|
||||
|
||||
// runningSet returns the union of live processes (State != Stopped/Shutdown)
|
||||
// and any extra IDs the baseRouter has already committed to loading but which
|
||||
// the process state machine has not yet reflected.
|
||||
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
|
||||
seen := make(map[string]struct{}, len(p.processes))
|
||||
var running []string
|
||||
for id, proc := range p.processes {
|
||||
st := proc.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
running = append(running, id)
|
||||
}
|
||||
for _, id := range alsoRunning {
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
running = append(running, id)
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
planner := &matrixPlanner{
|
||||
solver: newMatrixSolver(expanded, evictCosts),
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
swapper := &matrixSwapper{
|
||||
solver: newMatrixSolver(expanded, evictCosts),
|
||||
logger: logger,
|
||||
}
|
||||
base, err := newBaseRouter("matrix", conf, processes, logger, swapper)
|
||||
if err != nil {
|
||||
t.Fatalf("newBaseRouter: %v", err)
|
||||
}
|
||||
base := newBaseRouter("matrix", conf, processes, planner, logger)
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
r := &Matrix{baseRouter: base}
|
||||
go base.run()
|
||||
@@ -153,8 +155,8 @@ func TestMatrix_CoexistingSetParallel(t *testing.T) {
|
||||
|
||||
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
||||
// that cannot coexist with the in-flight first model queues until the first
|
||||
// completes, and then evicts it. This exercises the alsoRunning hint via the
|
||||
// matrix solver's union into runningSet.
|
||||
// completes, and then evicts it. This exercises the scheduler folding in-flight
|
||||
// swap targets into the running set it hands the swapper.
|
||||
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
@@ -173,8 +175,9 @@ func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
// B arrives before A transitions to StateStarting. The solver sees A via
|
||||
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
|
||||
// B arrives before A transitions to StateStarting. The running set the
|
||||
// scheduler builds includes A (an in-flight swap target), so the solver
|
||||
// returns evict=[a] and collidesWith forces B to queue.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
+11
-11
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type peerMember struct {
|
||||
@@ -62,13 +63,12 @@ func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = req.URL.Host
|
||||
reverseProxy := &httputil.ReverseProxy{
|
||||
Transport: peerTransport,
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
r.SetURL(peer.ProxyURL)
|
||||
r.Out.Host = r.Out.URL.Host
|
||||
},
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
@@ -147,22 +147,22 @@ func (r *Peer) Shutdown(timeout time.Duration) error {
|
||||
|
||||
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.shuttingDown.Load() {
|
||||
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||
return
|
||||
}
|
||||
r.inflight.Add(1)
|
||||
defer r.inflight.Done()
|
||||
|
||||
data, err := FetchContext(req, r.cfg)
|
||||
data, err := shared.FetchContext(req, r.cfg)
|
||||
if err != nil {
|
||||
SendError(w, req, err)
|
||||
shared.SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
pp, found := r.peers[data.ModelID]
|
||||
if !found {
|
||||
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||
SendError(w, req, ErrNoPeerModelFound)
|
||||
shared.SendError(w, req, ErrNoPeerModelFound)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var testLogger = logmon.NewWriter(os.Stdout)
|
||||
@@ -142,7 +143,7 @@ func TestPeer_ServeHTTP_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -178,7 +179,7 @@ func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -212,7 +213,7 @@ func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -246,7 +247,7 @@ func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -279,7 +280,7 @@ func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -311,7 +312,7 @@ func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -347,7 +348,7 @@ func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -385,7 +386,7 @@ func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -448,7 +449,7 @@ func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -551,7 +552,7 @@ func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
||||
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
+4
-151
@@ -1,39 +1,18 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type contextkey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type ReqContextData struct {
|
||||
Model string
|
||||
ModelID string
|
||||
Streaming bool
|
||||
SendLoadingState bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||
|
||||
ContextKey = &contextkey{"context"}
|
||||
ErrNoRouterFound = shared.ErrNoRouterFound
|
||||
ErrNoPeerModelFound = shared.ErrNoPeerModelFound
|
||||
ErrNoLocalModelFound = shared.ErrNoLocalModelFound
|
||||
)
|
||||
|
||||
type Router interface {
|
||||
@@ -71,129 +50,3 @@ type LocalRouter interface {
|
||||
// model is not known to this router.
|
||||
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context then
|
||||
// from the model body. If it extracts the model from the body it will
|
||||
// store the model in the context for downstream handlers. An error
|
||||
// will be returned when model can not be fetch from either location.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if data, err := ExtractContext(r); err == nil {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
realName = data.Model
|
||||
}
|
||||
data.ModelID = realName
|
||||
if mc, ok := cfg.Models[realName]; ok {
|
||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ContextKey, data)
|
||||
}
|
||||
|
||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// ExtractContext pulls the model name from an HTTP request without consuming the
|
||||
// body. For GET requests it reads the "model" query parameter. For POST
|
||||
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
|
||||
// application/x-www-form-urlencoded bodies. The request body is always restored
|
||||
// before returning so downstream handlers — including reverse proxies that
|
||||
// forward raw bytes upstream — can still read it.
|
||||
func ExtractContext(r *http.Request) (ReqContextData, error) {
|
||||
if r.Method == http.MethodGet {
|
||||
if model := r.URL.Query().Get("model"); model != "" {
|
||||
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
|
||||
}
|
||||
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}()
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
model := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if model == "" {
|
||||
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
|
||||
}
|
||||
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
|
||||
}
|
||||
|
||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||
// buffered bytes. The deferred restore above will reset r.Body again
|
||||
// after parsing.
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if model := r.FormValue("model"); model != "" {
|
||||
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
|
||||
}
|
||||
|
||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
switch {
|
||||
case errors.Is(err, ErrNoModelInContext):
|
||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||
case errors.Is(err, ErrNoPeerModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||
case errors.Is(err, ErrNoLocalModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
case errors.Is(err, ErrNoRouterFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||
default:
|
||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
// Check Accept header for preferred response format
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/plain") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(acceptHeader, "text/html") {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
|
||||
}
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractContext_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "model=llama3", "llama3", false},
|
||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||
{"model missing", "", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||
{"model empty string", `{"model":""}`, "", true},
|
||||
{"model key missing", `{"stream":true}`, "", true},
|
||||
{"invalid json", `not-json`, "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
form := url.Values{}
|
||||
if tt.formModel != "" {
|
||||
form.Set("model", tt.formModel)
|
||||
}
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
if tt.formModel != "" {
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte(tt.formModel))
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||
body := `{"model":"llama3","stream":true}`
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte("whisper-1"))
|
||||
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
||||
ff.Write([]byte("fake-audio-bytes"))
|
||||
mw.Close()
|
||||
|
||||
original := buf.Bytes()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if !bytes.Equal(remaining, original) {
|
||||
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||
body := "model=whisper-1&extra=value"
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||
if !ok {
|
||||
t.Fatalf("ContextKey not set or wrong type")
|
||||
}
|
||||
if data.Model != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_WithAlias(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||
data, _ := ctx.Value(ContextKey).(ReqContextData)
|
||||
if data.Model != "llama" {
|
||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||
parent := context.Background()
|
||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
if v := parent.Value(ContextKey); v != nil {
|
||||
t.Errorf("parent context was mutated: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
wantReq string
|
||||
wantReal string
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "model present, same name",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
||||
wantReq: "llama3",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model present, aliased",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
||||
wantReq: "llama",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model absent",
|
||||
ctx: context.Background(),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "model is empty string",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotData, ok := ReadContext(tt.ctx)
|
||||
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
||||
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,489 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||
// the model config leaves concurrencyLimit unset.
|
||||
const defaultConcurrencyLimit = 10
|
||||
|
||||
// activeSwap tracks one in-flight swap and the callers waiting on it.
|
||||
type activeSwap struct {
|
||||
modelID string
|
||||
evict []string
|
||||
waiters []HandlerReq
|
||||
}
|
||||
|
||||
// FIFO is the default scheduler. Requests are handled in a first-in, first-out order.
|
||||
// To reduce swapping requests for a model that is already running will be handled
|
||||
// immediately by the running process.
|
||||
//
|
||||
// Requests into this schedule are handled like this:
|
||||
//
|
||||
// A B C A B C --> A A B B C C
|
||||
//
|
||||
// The strategy is simple and reduces the number of swaps required.
|
||||
type FIFO struct {
|
||||
name string
|
||||
logger *logmon.Monitor
|
||||
planner Swapper
|
||||
cfg config.FifoConfig
|
||||
effects Effects
|
||||
|
||||
limits map[string]int
|
||||
active map[string]*activeSwap
|
||||
inFlight map[string]int
|
||||
queued []HandlerReq
|
||||
}
|
||||
|
||||
// NewFIFO builds a FIFO scheduler. Per-model concurrency limits are derived
|
||||
// from models: each model's ConcurrencyLimit overrides defaultConcurrencyLimit
|
||||
// when set to a value greater than zero.
|
||||
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, models map[string]config.ModelConfig, eff Effects) *FIFO {
|
||||
limits := make(map[string]int, len(models))
|
||||
for id, mc := range models {
|
||||
limit := defaultConcurrencyLimit
|
||||
if mc.ConcurrencyLimit > 0 {
|
||||
limit = mc.ConcurrencyLimit
|
||||
}
|
||||
limits[id] = limit
|
||||
}
|
||||
|
||||
return &FIFO{
|
||||
name: name,
|
||||
logger: logger,
|
||||
planner: planner,
|
||||
cfg: cfg,
|
||||
effects: eff,
|
||||
limits: limits,
|
||||
active: make(map[string]*activeSwap),
|
||||
inFlight: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest decides what to do with one incoming ServeHTTP request. It never
|
||||
// blocks indefinitely: any work that has to wait (starting a process, stopping
|
||||
// siblings, waiting for ready) is deferred to a swap goroutine and reported back
|
||||
// via OnSwapDone.
|
||||
//
|
||||
// The decision tree, in order:
|
||||
//
|
||||
// 1. Unknown model — respond with ErrModelNotFound and move on.
|
||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||
// one swap serves all callers that asked for the same model.
|
||||
// 3. Fast path — the target process is already ready, the planner sees
|
||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||
// ServeHTTP immediately.
|
||||
// 4. Would collide with an in-flight swap (we'd stop their target, or they're
|
||||
// stopping us) — park in the queue for OnSwapDone to drain.
|
||||
// 5. Would evict a process that is still handling requests — park in the
|
||||
// queue. OnServeDone will retry when the busy process drains.
|
||||
// 6. Otherwise — start a new swap. This may run in parallel with other active
|
||||
// swaps when their evict sets don't intersect.
|
||||
func (s *FIFO) OnRequest(req HandlerReq) {
|
||||
// (1) Unknown model.
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// (2) Join an in-flight swap for the same model.
|
||||
if sw, ok := s.active[req.Model]; ok {
|
||||
s.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", s.name, req.Model, len(sw.waiters)+1)
|
||||
sw.waiters = append(sw.waiters, req)
|
||||
return
|
||||
}
|
||||
|
||||
running := s.runningSet(req.Model)
|
||||
evict := s.planner.EvictionFor(req.Model, running)
|
||||
|
||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: fast-path serving model %s (already ready)", s.name, req.Model)
|
||||
s.grantHandler(req, req.Model)
|
||||
return
|
||||
}
|
||||
|
||||
// (4) Collision with an in-flight swap — queue.
|
||||
if collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", s.name, req.Model)
|
||||
s.enqueue(req)
|
||||
return
|
||||
}
|
||||
|
||||
// (5) Would evict a busy process — queue until it drains.
|
||||
if conflictsWithInFlight(evict, s.inFlight) {
|
||||
s.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", s.name, req.Model)
|
||||
s.enqueue(req)
|
||||
return
|
||||
}
|
||||
|
||||
// (6) Start a new (possibly parallel) swap.
|
||||
s.logger.Debugf("%s: starting swap for model %s, evicting %v", s.name, req.Model, evict)
|
||||
s.startSwap(req, evict, running)
|
||||
}
|
||||
|
||||
// OnCancel removes a request whose client has disconnected from the queue and
|
||||
// from every in-flight swap's waiters. If the request was the sole waiter of an
|
||||
// active swap, the swap goroutine is left to complete on its own — OnSwapDone
|
||||
// will find no waiters and simply clean up. This prevents drainQueue from ever
|
||||
// starting a model load for a caller that is no longer there.
|
||||
func (s *FIFO) OnCancel(req HandlerReq) {
|
||||
removed := false
|
||||
|
||||
// Prune from the queue.
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, q := range s.queued {
|
||||
if q.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
}
|
||||
|
||||
// Prune from any active swap's waiters.
|
||||
for _, sw := range s.active {
|
||||
filtered := sw.waiters[:0]
|
||||
for _, w := range sw.waiters {
|
||||
if w.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, w)
|
||||
}
|
||||
sw.waiters = filtered
|
||||
}
|
||||
|
||||
if removed {
|
||||
s.logger.Debugf("%s: cancelled request for model %s pruned from scheduler", s.name, req.Model)
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
}
|
||||
|
||||
// OnSwapDone fans the result out to every waiter that joined this swap, removes
|
||||
// the swap from the active map, then walks the queue once, promoting any items
|
||||
// that no longer collide with the remaining active set. FIFO order is preserved:
|
||||
// items still blocked stay in place.
|
||||
func (s *FIFO) OnSwapDone(ev SwapDone) {
|
||||
sw, ok := s.active[ev.ModelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(s.active, ev.ModelID)
|
||||
|
||||
for _, w := range sw.waiters {
|
||||
if ev.Err != nil {
|
||||
s.effects.GrantError(w, ev.Err)
|
||||
} else {
|
||||
s.grantHandler(w, ev.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
s.drainQueue()
|
||||
}
|
||||
|
||||
// OnServeDone decrements the per-model in-flight count and, when that drops to
|
||||
// zero, retries the queue: requests whose swap was deferred because they would
|
||||
// have evicted this (now-idle) process can now proceed.
|
||||
func (s *FIFO) OnServeDone(ev ServeDoneEvent) {
|
||||
s.inFlight[ev.ModelID]--
|
||||
if s.inFlight[ev.ModelID] <= 0 {
|
||||
delete(s.inFlight, ev.ModelID)
|
||||
s.drainQueue()
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnload reconciles router-owned state with the impending Stop, performs the
|
||||
// Stop (synchronously, via Effects) so callers of Unload remain blocked until
|
||||
// each targeted process has exited, then drains the queue.
|
||||
func (s *FIFO) OnUnload(targets []string, timeout time.Duration) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(targets))
|
||||
for _, id := range targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
// Release waiters of any in-flight swap whose target is being unloaded.
|
||||
// The swap goroutine itself is left to finish on its own; when its
|
||||
// SwapDone arrives, OnSwapDone will find no entry in active and drop it.
|
||||
for id := range targetSet {
|
||||
sw, ok := s.active[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, w := range sw.waiters {
|
||||
s.effects.GrantError(w, unloadErr)
|
||||
}
|
||||
delete(s.active, id)
|
||||
}
|
||||
|
||||
// Drop queued requests addressed to unloaded models. Requests for other
|
||||
// models stay queued and may benefit from drainQueue at the end.
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, w := range s.queued {
|
||||
if targetSet[w.Model] {
|
||||
s.effects.GrantError(w, unloadErr)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, w)
|
||||
}
|
||||
s.queued = kept
|
||||
}
|
||||
|
||||
// Stop the targeted processes. Done synchronously so Unload's caller can
|
||||
// rely on "after Unload returns, the process is stopped". inFlight is
|
||||
// intentionally NOT cleared here: each dying handler will fire its tracked
|
||||
// serve and reach OnServeDone in the normal way.
|
||||
s.effects.StopProcesses(timeout, targets)
|
||||
|
||||
// Removing entries from active above may have unblocked queued requests
|
||||
// that previously collided with the now-cancelled swaps.
|
||||
s.drainQueue()
|
||||
}
|
||||
|
||||
// OnShutdown grants err to every waiter still held by the scheduler.
|
||||
func (s *FIFO) OnShutdown(err error) {
|
||||
for _, sw := range s.active {
|
||||
for _, w := range sw.waiters {
|
||||
s.effects.GrantError(w, err)
|
||||
}
|
||||
}
|
||||
for _, w := range s.queued {
|
||||
s.effects.GrantError(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
// grantHandler hands the caller a tracked handler for modelID and, only if the
|
||||
// caller was still there to receive it, bumps the in-flight count. Incrementing
|
||||
// when the grant failed would strand the counter and block future evictions.
|
||||
// Requests that would exceed the model's concurrency limit are rejected with a
|
||||
// shared.NewConcurrencyLimitError (HTTP 429 with Retry-After).
|
||||
func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
|
||||
if s.inFlight[modelID] >= s.limit(modelID) {
|
||||
s.effects.GrantError(req, shared.ConcurrencyLimitError{})
|
||||
return
|
||||
}
|
||||
|
||||
if err := shared.SetReqData(req.Ctx, "fifo_priority", strconv.Itoa(s.cfg.Priority[req.Model])); err != nil {
|
||||
s.logger.Debugf("failed to set fifo_priority metadata: %v", err)
|
||||
}
|
||||
|
||||
if s.effects.GrantServe(req, modelID) {
|
||||
s.inFlight[modelID]++
|
||||
}
|
||||
}
|
||||
|
||||
// limit returns the per-model concurrency cap, defaulting to
|
||||
// defaultConcurrencyLimit when the model has no explicit entry.
|
||||
func (s *FIFO) limit(modelID string) int {
|
||||
if l, ok := s.limits[modelID]; ok {
|
||||
return l
|
||||
}
|
||||
return defaultConcurrencyLimit
|
||||
}
|
||||
|
||||
// startSwap records the swap as active and launches it via Effects. running is
|
||||
// the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against
|
||||
// the same picture it decided on.
|
||||
func (s *FIFO) startSwap(initial HandlerReq, evict, running []string) {
|
||||
s.active[initial.Model] = &activeSwap{
|
||||
modelID: initial.Model,
|
||||
evict: evict,
|
||||
waiters: []HandlerReq{initial},
|
||||
}
|
||||
s.planner.OnSwapStart(initial.Model, running)
|
||||
s.effects.StartSwap(initial.Model, evict)
|
||||
}
|
||||
|
||||
// enqueue inserts req into the queue in priority order: it goes just before the
|
||||
// first queued item whose priority is strictly lower, so higher-priority models
|
||||
// are serviced first while equal-priority requests keep their arrival (FIFO)
|
||||
// order. Priorities come from the FifoConfig; unlisted models default to 0.
|
||||
func (s *FIFO) enqueue(req HandlerReq) {
|
||||
p := s.cfg.Priority[req.Model]
|
||||
i := len(s.queued)
|
||||
for j, q := range s.queued {
|
||||
if s.cfg.Priority[q.Model] < p {
|
||||
i = j
|
||||
break
|
||||
}
|
||||
}
|
||||
s.queued = append(s.queued, HandlerReq{})
|
||||
copy(s.queued[i+1:], s.queued[i:])
|
||||
s.queued[i] = req
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
// drainQueue walks the queued requests in order, re-running the OnRequest
|
||||
// decision tree against the (now smaller) active set. Items that can now start
|
||||
// or join become satisfied; items still blocked remain queued in original order
|
||||
// so they get another chance on the next swap completion.
|
||||
func (s *FIFO) drainQueue() {
|
||||
if len(s.queued) == 0 {
|
||||
return
|
||||
}
|
||||
pending := s.queued
|
||||
var remaining []HandlerReq
|
||||
for _, req := range pending {
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
continue
|
||||
}
|
||||
if sw, ok := s.active[req.Model]; ok {
|
||||
s.logger.Debugf("%s: queued request for model %s now joining in-flight swap", s.name, req.Model)
|
||||
sw.waiters = append(sw.waiters, req)
|
||||
continue
|
||||
}
|
||||
running := s.runningSet(req.Model)
|
||||
evict := s.planner.EvictionFor(req.Model, running)
|
||||
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: queued request for model %s now served fast-path", s.name, req.Model)
|
||||
s.grantHandler(req, req.Model)
|
||||
continue
|
||||
}
|
||||
if collidesWith(req.Model, evict, s.active) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
if conflictsWithInFlight(evict, s.inFlight) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
s.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", s.name, req.Model, evict)
|
||||
s.startSwap(req, evict, running)
|
||||
}
|
||||
s.queued = remaining
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
// runningSet is the live model set handed to the Swapper: every process the
|
||||
// baseRouter reports as running, unioned with the targets of in-flight swaps
|
||||
// (excluding excludeActive, the model whose own swap is being decided — its
|
||||
// in-flight entry must not count as "already running"). The result is sorted so
|
||||
// eviction decisions derived from it are deterministic.
|
||||
func (s *FIFO) runningSet(excludeActive string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
add := func(id string) {
|
||||
if _, dup := seen[id]; dup {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
for id := range s.effects.RunningModels() {
|
||||
add(id)
|
||||
}
|
||||
for _, id := range activeTargets(s.active, excludeActive) {
|
||||
add(id)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||
// The planner uses this to account for models committed to but not yet reflected
|
||||
// in process state.
|
||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||
if len(active) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(active))
|
||||
for id := range active {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// collidesWith reports whether a new swap with this target and evict set can
|
||||
// safely run alongside the currently active swaps. Same-target callers should
|
||||
// JOIN (handled before this) — they do not collide with themselves.
|
||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||
for id, sw := range active {
|
||||
if id == target {
|
||||
continue
|
||||
}
|
||||
if containsString(evict, id) {
|
||||
return true
|
||||
}
|
||||
if containsString(sw.evict, target) {
|
||||
return true
|
||||
}
|
||||
if slicesOverlap(evict, sw.evict) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// slicesOverlap reports whether xs and ys share any common element.
|
||||
func slicesOverlap(xs, ys []string) bool {
|
||||
for _, x := range xs {
|
||||
if containsString(ys, x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
||||
// requests. Stopping a busy process would cancel its callers' connections, so
|
||||
// the scheduler defers the swap until those callers finish.
|
||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
||||
for _, m := range evict {
|
||||
if inFlight[m] > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(xs []string, s string) bool {
|
||||
for _, x := range xs {
|
||||
if x == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||
// drained first so the consumer always sees the latest position.
|
||||
func broadcastQueuePositions(queued []HandlerReq) {
|
||||
for i, req := range queued {
|
||||
pos := i + 1
|
||||
select {
|
||||
case req.PositionCh <- pos:
|
||||
default:
|
||||
select {
|
||||
case <-req.PositionCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case req.PositionCh <- pos:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,779 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// FIFO methods all run on the router's single run-loop goroutine, so these
|
||||
// tests drive them directly and synchronously. A swap is "completed" by calling
|
||||
// OnSwapDone, a served request "finishes" by calling OnServeDone — exactly the
|
||||
// events the run loop would deliver. fakeEffects records every side-effect and
|
||||
// stubPlanner supplies a fixed eviction set per target.
|
||||
|
||||
// stubPlanner returns a fixed eviction list per target.
|
||||
type stubPlanner struct {
|
||||
evict map[string][]string
|
||||
}
|
||||
|
||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||
if s.evict == nil {
|
||||
return nil
|
||||
}
|
||||
return s.evict[target]
|
||||
}
|
||||
|
||||
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||
|
||||
// grantRec is one GrantError / GrantServe call. err!=nil marks an error grant;
|
||||
// otherwise it is a serve grant and serve reports whether the caller received it.
|
||||
type grantRec struct {
|
||||
model string
|
||||
err error
|
||||
serve bool
|
||||
}
|
||||
|
||||
type startRec struct {
|
||||
model string
|
||||
evict []string
|
||||
}
|
||||
|
||||
type stopRec struct {
|
||||
timeout time.Duration
|
||||
ids []string
|
||||
}
|
||||
|
||||
// fakeEffects is an in-memory scheduler.Effects. Tests program process states
|
||||
// and GrantServe outcomes, then assert on the recorded calls.
|
||||
type fakeEffects struct {
|
||||
states map[string]process.ProcessState // model -> state; missing => not handled
|
||||
serveResult map[string]bool // GrantServe return per model (default true)
|
||||
lastServeReq HandlerReq
|
||||
|
||||
starts []startRec
|
||||
grants []grantRec
|
||||
stops []stopRec
|
||||
}
|
||||
|
||||
func newFakeEffects() *fakeEffects {
|
||||
return &fakeEffects{
|
||||
states: map[string]process.ProcessState{},
|
||||
serveResult: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeEffects) ModelState(modelID string) (process.ProcessState, bool) {
|
||||
st, ok := f.states[modelID]
|
||||
return st, ok
|
||||
}
|
||||
|
||||
func (f *fakeEffects) RunningModels() map[string]process.ProcessState {
|
||||
out := make(map[string]process.ProcessState)
|
||||
for id, st := range f.states {
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
out[id] = st
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (f *fakeEffects) StartSwap(modelID string, evict []string) {
|
||||
f.starts = append(f.starts, startRec{model: modelID, evict: evict})
|
||||
}
|
||||
|
||||
func (f *fakeEffects) GrantError(req HandlerReq, err error) {
|
||||
f.grants = append(f.grants, grantRec{model: req.Model, err: err})
|
||||
}
|
||||
|
||||
func (f *fakeEffects) GrantServe(req HandlerReq, modelID string) bool {
|
||||
ok := true
|
||||
if v, set := f.serveResult[modelID]; set {
|
||||
ok = v
|
||||
}
|
||||
f.lastServeReq = req
|
||||
f.grants = append(f.grants, grantRec{model: modelID, serve: ok})
|
||||
return ok
|
||||
}
|
||||
|
||||
func (f *fakeEffects) StopProcesses(timeout time.Duration, ids []string) {
|
||||
f.stops = append(f.stops, stopRec{timeout: timeout, ids: ids})
|
||||
}
|
||||
|
||||
// served counts grants that handed modelID a handler and were received.
|
||||
func (f *fakeEffects) served(modelID string) int {
|
||||
n := 0
|
||||
for _, g := range f.grants {
|
||||
if g.err == nil && g.serve && g.model == modelID {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// errored counts error grants, optionally filtered by model ("" = any).
|
||||
func (f *fakeEffects) errored(model string) int {
|
||||
n := 0
|
||||
for _, g := range f.grants {
|
||||
if g.err != nil && (model == "" || g.model == model) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// startsFor counts StartSwap calls for modelID.
|
||||
func (f *fakeEffects) startsFor(modelID string) int {
|
||||
n := 0
|
||||
for _, s := range f.starts {
|
||||
if s.model == modelID {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func newFIFO(planner Swapper, eff Effects) *FIFO {
|
||||
return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, nil, eff)
|
||||
}
|
||||
|
||||
func req(model string) HandlerReq { return HandlerReq{Model: model} }
|
||||
|
||||
// reqCh creates a HandlerReq with a unique Respond channel so OnCancel can
|
||||
// identify it among queued requests and swap waiters.
|
||||
func reqCh(model string) HandlerReq {
|
||||
return HandlerReq{
|
||||
Model: model,
|
||||
Respond: make(chan HandlerResp, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_FastPath(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.startsFor("a"); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0 (fast path should not swap)", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_GrantSetsPriorityMetadata(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
cfg := config.FifoConfig{Priority: map[string]int{"a": 7}}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, cfg, nil, eff)
|
||||
|
||||
ctx := shared.SetContext(context.Background(), shared.ReqContextData{ModelID: "a", Metadata: make(map[string]string)})
|
||||
s.OnRequest(HandlerReq{Model: "a", Ctx: ctx})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
data, ok := shared.ReadContext(eff.lastServeReq.Ctx)
|
||||
if !ok {
|
||||
t.Fatal("context data missing from granted request")
|
||||
}
|
||||
if data.Metadata["fifo_priority"] != "7" {
|
||||
t.Errorf("fifo_priority = %q, want 7", data.Metadata["fifo_priority"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_ModelNotFound(t *testing.T) {
|
||||
eff := newFakeEffects() // no states => model unknown
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("ghost"))
|
||||
|
||||
if got := len(eff.starts); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0", got)
|
||||
}
|
||||
if eff.errored("ghost") != 1 {
|
||||
t.Fatalf("want 1 error grant for ghost, grants=%+v", eff.grants)
|
||||
}
|
||||
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnDemandStartThenServe(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 before swap completes", got)
|
||||
}
|
||||
|
||||
// Swap finishes, model is now ready.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 after swap done", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_JoinInFlightSwap(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // starts swap
|
||||
s.OnRequest(req("a")) // joins
|
||||
s.OnRequest(req("a")) // joins
|
||||
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1 (all three share one swap)", got)
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 3 {
|
||||
t.Errorf("served(a)=%d want 3 (one swap serves all waiters)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_SwapDoneError_FailsAllWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||
|
||||
if eff.served("a") != 0 {
|
||||
t.Errorf("served(a)=%d want 0 on swap error", eff.served("a"))
|
||||
}
|
||||
if eff.errored("a") != 2 {
|
||||
t.Errorf("errored(a)=%d want 2 (both waiters fail)", eff.errored("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueOnEvictionCollision covers a request whose target evicts the
|
||||
// model currently being swapped: it must queue until that swap finishes AND its
|
||||
// served request drains, because starting it would stop a busy process.
|
||||
func TestFIFO_QueueOnEvictionCollision(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// Loading b evicts a.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("b")) // collides with a's in-flight swap -> queue
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started early: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a becomes ready and is granted (now serving, inFlight[a]=1).
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started while a is serving: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a's request finishes -> a no longer in-flight -> b may now swap.
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a drained", got)
|
||||
}
|
||||
if got := eff.starts[len(eff.starts)-1].evict; len(got) != 1 || got[0] != "a" {
|
||||
t.Errorf("b swap evict=%v want [a]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_DisjointSwapsRunInParallel verifies two requests with
|
||||
// non-conflicting evict sets both start without waiting for each other.
|
||||
func TestFIFO_DisjointSwapsRunInParallel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff) // empty evicts
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b"))
|
||||
|
||||
if eff.startsFor("a") != 1 || eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap a=%d b=%d want 1 each (parallel)", eff.startsFor("a"), eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OverlappingEvictSetsDoNotRunInParallel verifies two swaps with
|
||||
// different targets that evict the *same* model do not run concurrently: the
|
||||
// second must queue rather than double-evict the shared model. Neither target is
|
||||
// in the other's evict set, so this is only caught by the evict-set overlap
|
||||
// check in collidesWith.
|
||||
func TestFIFO_OverlappingEvictSetsDoNotRunInParallel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["x"] = process.StateReady // shared eviction target, running
|
||||
// Loading a or b both require evicting x.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"x"}, "b": {"x"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a, [x])
|
||||
s.OnRequest(req("b")) // overlaps a's evict set ([x]) -> queue
|
||||
if eff.startsFor("a") != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", eff.startsFor("a"))
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started in parallel while a evicts x: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a's swap completes and x is gone; b can now evict nothing and start.
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["x"] = process.StateStopped
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a's swap drained", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueDrainPromotesMultiple verifies completing one swap unblocks
|
||||
// every queued request that no longer collides — they all start together.
|
||||
func TestFIFO_QueueDrainPromotesMultiple(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["c"] = process.StateStopped
|
||||
// a's swap evicts both b and c; b and c evict nothing.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"b", "c"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a, [b,c])
|
||||
s.OnRequest(req("b")) // collides (in a's evict set) -> queue
|
||||
s.OnRequest(req("c")) // collides -> queue
|
||||
if eff.startsFor("b") != 0 || eff.startsFor("c") != 0 {
|
||||
t.Fatalf("b/c started early")
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
// b and c have empty evict sets and don't evict a, so both start now.
|
||||
if eff.startsFor("b") != 1 || eff.startsFor("c") != 1 {
|
||||
t.Fatalf("StartSwap b=%d c=%d want 1 each after a done", eff.startsFor("b"), eff.startsFor("c"))
|
||||
}
|
||||
if eff.served("a") != 1 {
|
||||
t.Errorf("served(a)=%d want 1", eff.served("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueCollation verifies duplicate requests collapse into one swap
|
||||
// per model: the second request for each model joins the active swap (at arrival
|
||||
// or at drain time) rather than triggering its own swap.
|
||||
func TestFIFO_QueueCollation(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
eff.states[id] = process.StateStopped
|
||||
}
|
||||
// Each model evicts the other two: all swaps are mutually exclusive.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
"b": {"a", "c"},
|
||||
"c": {"a", "b"},
|
||||
}}, eff)
|
||||
|
||||
for _, id := range []string{"a", "b", "c", "a", "b", "c"} {
|
||||
s.OnRequest(req(id))
|
||||
}
|
||||
|
||||
// Drain a, then its served requests, which promotes b; repeat for b -> c.
|
||||
drain := func(model string, waiters int) {
|
||||
eff.states[model] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: model})
|
||||
for i := 0; i < waiters; i++ {
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||
}
|
||||
}
|
||||
drain("a", 2)
|
||||
drain("b", 2)
|
||||
drain("c", 2)
|
||||
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
if got := eff.startsFor(id); got != 1 {
|
||||
t.Errorf("StartSwap(%s)=%d want 1 (collation)", id, got)
|
||||
}
|
||||
if got := eff.served(id); got != 2 {
|
||||
t.Errorf("served(%s)=%d want 2", id, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_NoSwapWhileServing verifies a model still handling requests is not
|
||||
// evicted: the evicting request waits until every in-flight request drains.
|
||||
func TestFIFO_NoSwapWhileServing(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // fast path, inFlight[a]=1
|
||||
s.OnRequest(req("a")) // fast path, inFlight[a]=2
|
||||
s.OnRequest(req("b")) // would evict busy a -> queue
|
||||
if eff.startsFor("b") != 0 {
|
||||
t.Fatalf("b started while a serving")
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=1
|
||||
if eff.startsFor("b") != 0 {
|
||||
t.Fatalf("b started while a still serving one request")
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=0
|
||||
if eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a fully drained", eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_GrantServeFalseDoesNotLeakInFlight verifies that when a caller has
|
||||
// walked away (GrantServe returns false) the in-flight count is not bumped, so a
|
||||
// later evicting request is not blocked forever.
|
||||
func TestFIFO_GrantServeFalseDoesNotLeakInFlight(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.serveResult["a"] = false // a's waiter is gone by grant time
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails, inFlight[a] stays 0
|
||||
|
||||
// b evicts a; since a is not in-flight, b should start immediately.
|
||||
s.OnRequest(req("b"))
|
||||
if eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 (no leaked in-flight on a)", eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnShutdown_FailsAllWaiters verifies shutdown errors every waiter the
|
||||
// scheduler holds: active-swap waiters and queued requests alike.
|
||||
func TestFIFO_OnShutdown_FailsAllWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
eff.states[id] = process.StateStopped
|
||||
}
|
||||
// a and b load in parallel; c collides with both and queues.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"c": {"a", "b"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("a")) // join a
|
||||
s.OnRequest(req("b")) // StartSwap(b)
|
||||
s.OnRequest(req("b")) // join b
|
||||
s.OnRequest(req("c")) // queued
|
||||
|
||||
s.OnShutdown(errors.New("shutting down"))
|
||||
|
||||
if got := eff.errored(""); got != 5 {
|
||||
t.Errorf("error grants=%d want 5 (2 a + 2 b + 1 c)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnUnload_ReleasesActiveWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // active swap a with one waiter
|
||||
s.OnRequest(req("a")) // join
|
||||
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
|
||||
if got := eff.errored("a"); got != 2 {
|
||||
t.Errorf("errored(a)=%d want 2 (active swap waiters released)", got)
|
||||
}
|
||||
if len(eff.stops) != 1 || len(eff.stops[0].ids) != 1 || eff.stops[0].ids[0] != "a" {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
if eff.stops[0].timeout != time.Second {
|
||||
t.Errorf("StopProcesses timeout=%v want 1s", eff.stops[0].timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnUnload_DropsQueuedRequests(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// b evicts a, so a request for b queues while a is loading.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
s.OnUnload([]string{"b"}, time.Second)
|
||||
|
||||
if got := eff.errored("b"); got != 1 {
|
||||
t.Errorf("errored(b)=%d want 1 (queued request dropped)", got)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (b should never start)", got)
|
||||
}
|
||||
// a's swap is untouched: its waiter is neither served nor errored yet.
|
||||
if eff.served("a") != 0 || eff.errored("a") != 0 {
|
||||
t.Errorf("a swap should be untouched: served=%d errored=%d", eff.served("a"), eff.errored("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_PriorityQueueOrder verifies queued requests are ordered by descending
|
||||
// priority, with arrival (FIFO) order preserved among equal-priority models.
|
||||
func TestFIFO_PriorityQueueOrder(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, m := range []string{"z", "A", "B", "C", "D"} {
|
||||
eff.states[m] = process.StateStopped
|
||||
}
|
||||
// z's swap evicts every other model, so any request that arrives while z is
|
||||
// loading collides with z's in-flight swap and parks in the queue.
|
||||
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}}
|
||||
cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, nil, eff)
|
||||
|
||||
s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D])
|
||||
|
||||
// Arrive out of priority order; B before C exercises FIFO tie-breaking.
|
||||
for _, m := range []string{"B", "D", "C", "A"} {
|
||||
s.OnRequest(req(m))
|
||||
}
|
||||
|
||||
got := make([]string, len(s.queued))
|
||||
for i, q := range s.queued {
|
||||
got[i] = q.Model
|
||||
}
|
||||
want := []string{"A", "B", "C", "D"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("queue=%v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("queue=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_QueuedRequest verifies that cancelling a queued request
|
||||
// prevents drainQueue from ever starting a model load for it. Without OnCancel
|
||||
// the dead request would sit in the queue until a drain triggers a wasted swap.
|
||||
func TestFIFO_OnCancel_QueuedRequest(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// b evicts a, so a request for b queues while a is loading.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
|
||||
cancelledReq := reqCh("b")
|
||||
s.OnRequest(cancelledReq) // queued (collides with a's in-flight swap)
|
||||
if len(s.queued) != 1 {
|
||||
t.Fatalf("queue len=%d want 1 before cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// Client disconnects.
|
||||
s.OnCancel(cancelledReq)
|
||||
|
||||
if len(s.queued) != 0 {
|
||||
t.Fatalf("queue len=%d want 0 after cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// a's swap finishes; drainQueue runs but b is gone — no swap for b.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (cancelled request should not trigger a load)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_SwapWaiter verifies that cancelling a request that joined an
|
||||
// in-flight swap removes it from the waiter list. When the swap completes, the
|
||||
// cancelled waiter receives no grant and does not bump the in-flight count.
|
||||
func TestFIFO_OnCancel_SwapWaiter(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
liveReq := reqCh("a")
|
||||
cancelledReq := reqCh("a")
|
||||
s.OnRequest(liveReq) // starts swap
|
||||
s.OnRequest(cancelledReq) // joins
|
||||
|
||||
if sw := s.active["a"]; len(sw.waiters) != 2 {
|
||||
t.Fatalf("waiters=%d want 2", len(sw.waiters))
|
||||
}
|
||||
|
||||
s.OnCancel(cancelledReq)
|
||||
|
||||
if sw := s.active["a"]; len(sw.waiters) != 1 {
|
||||
t.Fatalf("waiters=%d want 1 after cancel", len(sw.waiters))
|
||||
}
|
||||
|
||||
// Swap finishes: only the live waiter is granted.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 (only the non-cancelled waiter)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_NotPresent is a no-op: cancelling a request that was already
|
||||
// granted (and is no longer queued or waiting) must not affect anything.
|
||||
func TestFIFO_OnCancel_NotPresent(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
r := reqCh("a")
|
||||
s.OnRequest(r) // fast-path served immediately
|
||||
|
||||
// Cancel after grant — should be a harmless no-op.
|
||||
s.OnCancel(r)
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 (cancel of granted request is a no-op)", got)
|
||||
}
|
||||
if len(s.queued) != 0 {
|
||||
t.Errorf("queue should be empty, len=%d", len(s.queued))
|
||||
}
|
||||
}
|
||||
|
||||
// newFIFOWithLimit builds a FIFO whose single model has the given concurrency
|
||||
// limit, already in StateReady so every request exercises the fast path.
|
||||
func newFIFOWithLimit(t *testing.T, model string, limit int) (*FIFO, *fakeEffects) {
|
||||
t.Helper()
|
||||
eff := newFakeEffects()
|
||||
eff.states[model] = process.StateReady
|
||||
models := map[string]config.ModelConfig{
|
||||
model: {ConcurrencyLimit: limit},
|
||||
}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
|
||||
return s, eff
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_RejectsOverLimit verifies that a request arriving
|
||||
// while the model is at capacity gets an error grant instead of being served,
|
||||
// and that a new request succeeds once an in-flight one completes.
|
||||
func TestFIFO_ConcurrencyLimit_RejectsOverLimit(t *testing.T) {
|
||||
s, eff := newFIFOWithLimit(t, "a", 1)
|
||||
|
||||
// First request: served (inFlight 0 → 1).
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Second request while slot is occupied: rejected with HTTPError 429.
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (over-limit)", got)
|
||||
}
|
||||
var httpErr shared.HTTPError
|
||||
if !errors.As(eff.grants[len(eff.grants)-1].err, &httpErr) {
|
||||
t.Fatalf("err=%v want HTTPError", eff.grants[len(eff.grants)-1].err)
|
||||
}
|
||||
if httpErr.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("StatusCode()=%d want 429", httpErr.StatusCode())
|
||||
}
|
||||
if httpErr.Header().Get("Retry-After") == "" {
|
||||
t.Fatal("missing Retry-After header")
|
||||
}
|
||||
|
||||
// After the in-flight request finishes, a new request succeeds.
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2 after drain", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_DefaultIsTen verifies that a model without an
|
||||
// explicit ConcurrencyLimit gets the default cap of 10.
|
||||
func TestFIFO_ConcurrencyLimit_DefaultIsTen(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
// nil models → every model gets defaultConcurrencyLimit (10).
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
s.OnRequest(req("a"))
|
||||
}
|
||||
if got := eff.served("a"); got != 10 {
|
||||
t.Fatalf("served(a)=%d want 10 (default limit)", got)
|
||||
}
|
||||
|
||||
// 11th request is rejected.
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (over default limit)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_CustomLimit verifies a ConcurrencyLimit greater
|
||||
// than zero overrides the default.
|
||||
func TestFIFO_ConcurrencyLimit_CustomLimit(t *testing.T) {
|
||||
s, eff := newFIFOWithLimit(t, "a", 2)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2 (custom limit)", got)
|
||||
}
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (over custom limit)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_ConcurrencyLimit_SwapWaiters verifies that when more swap waiters
|
||||
// exist than the concurrency limit, excess waiters are rejected on swap
|
||||
// completion rather than exceeding the limit.
|
||||
func TestFIFO_ConcurrencyLimit_SwapWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
models := map[string]config.ModelConfig{
|
||||
"a": {ConcurrencyLimit: 2},
|
||||
}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)
|
||||
|
||||
// Three requests arrive while model is loading: one starts swap, two join.
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Swap completes: two served (limit), one rejected.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2 (limit on swap completion)", got)
|
||||
}
|
||||
if got := eff.errored("a"); got != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1 (excess waiter rejected)", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
// Package scheduler contains the request-scheduling strategies used by the
|
||||
// router's baseRouter. A Scheduler owns the queue, in-flight tracking, and the
|
||||
// decision tree for when to start a swap versus queue a request. The baseRouter
|
||||
// owns the channels, run loop, and process machinery, and exposes the
|
||||
// side-effects a scheduler needs through the Effects interface.
|
||||
//
|
||||
// Splitting these apart lets the scheduling strategy be swapped out
|
||||
// independently of both the process machinery (baseRouter) and the eviction
|
||||
// policy (Swapper). FIFO is the first and currently only implementation.
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// ErrModelNotFound is granted to callers whose model is not handled by this
|
||||
// router. It is an alias for shared.ErrNoLocalModelFound.
|
||||
var ErrModelNotFound = shared.ErrNoLocalModelFound
|
||||
|
||||
// Swapper is the eviction policy: it decides which running models must be
|
||||
// stopped before a target can serve. It is orthogonal to the scheduling
|
||||
// strategy — any Scheduler works with any Swapper.
|
||||
type Swapper interface {
|
||||
// EvictionFor returns running model IDs that must be stopped before
|
||||
// target can serve. running is the complete set the scheduler considers
|
||||
// live: every process that is not stopped, unioned with the targets of
|
||||
// in-flight swaps the scheduler has already committed to (which are not yet
|
||||
// visible in process state). The planner does not inspect process state
|
||||
// itself. Pure decision; must not log.
|
||||
EvictionFor(target string, running []string) []string
|
||||
|
||||
// OnSwapStart runs once at the start of every swap, with the same running
|
||||
// set EvictionFor was given for this decision. Planners may log their
|
||||
// decision here at whatever verbosity they choose.
|
||||
OnSwapStart(target string, running []string)
|
||||
}
|
||||
|
||||
// Scheduler decides what happens to each event the router's run loop receives.
|
||||
// All methods run on that single run-loop goroutine, so implementations need no
|
||||
// internal locking for their own state.
|
||||
type Scheduler interface {
|
||||
// OnRequest handles one incoming ServeHTTP request.
|
||||
OnRequest(req HandlerReq)
|
||||
// OnCancel handles a request whose client has disconnected before it was
|
||||
// granted. The scheduler must remove the request from its queue and from
|
||||
// any in-flight swap's waiters so it never triggers a model load or grant
|
||||
// for a caller that is no longer there.
|
||||
OnCancel(req HandlerReq)
|
||||
// OnSwapDone handles a swap goroutine reporting completion.
|
||||
OnSwapDone(ev SwapDone)
|
||||
// OnServeDone handles a tracked ServeHTTP finishing (in-flight decrement).
|
||||
OnServeDone(ev ServeDoneEvent)
|
||||
// OnUnload reconciles scheduler state for an unload, stops the targeted
|
||||
// processes via Effects, and drains the queue. It must block until the
|
||||
// targeted processes have stopped.
|
||||
OnUnload(targets []string, timeout time.Duration)
|
||||
// OnShutdown grants err to every waiter the scheduler still holds (active
|
||||
// swap waiters and queued requests). Process teardown is the baseRouter's
|
||||
// responsibility.
|
||||
OnShutdown(err error)
|
||||
}
|
||||
|
||||
// Effects is implemented by the baseRouter. The scheduler calls back through it
|
||||
// for every side-effect: inspecting process state, launching swaps, responding
|
||||
// to callers, and stopping processes.
|
||||
type Effects interface {
|
||||
// ModelState returns the current state of a model's process. ok is false
|
||||
// when the model is not handled by this router.
|
||||
ModelState(modelID string) (process.ProcessState, bool)
|
||||
// RunningModels returns the state of every process that is not stopped or
|
||||
// shut down, keyed by model ID. The scheduler uses it to build the running
|
||||
// set it hands the Swapper.
|
||||
RunningModels() map[string]process.ProcessState
|
||||
// StartSwap launches the swap goroutine for modelID, stopping evict first.
|
||||
StartSwap(modelID string, evict []string)
|
||||
// GrantError responds to a caller with an error.
|
||||
GrantError(req HandlerReq, err error)
|
||||
// GrantServe hands a caller the wrapped handler for modelID and reports
|
||||
// whether the caller was still there to receive it. The scheduler bumps
|
||||
// its in-flight count only when this returns true.
|
||||
GrantServe(req HandlerReq, modelID string) bool
|
||||
// StopProcesses stops the named processes in parallel and blocks until all
|
||||
// have stopped. Unknown IDs are skipped.
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
|
||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
|
||||
// from conf and bound to the given planner and effects. Currently only "fifo"
|
||||
// (the default) is supported.
|
||||
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||
use := conf.Routing.Scheduler.Use
|
||||
if use == "" {
|
||||
use = "fifo"
|
||||
}
|
||||
switch use {
|
||||
case "fifo":
|
||||
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlerReq is one in-flight ServeHTTP request waiting for a routing decision.
|
||||
type HandlerReq struct {
|
||||
Model string
|
||||
Ctx context.Context
|
||||
Respond chan HandlerResp
|
||||
PositionCh chan int
|
||||
}
|
||||
|
||||
// HandlerResp is the routing decision returned to a HandlerReq's caller: either
|
||||
// a handler to serve with, or an error.
|
||||
type HandlerResp struct {
|
||||
HandleFunc http.HandlerFunc
|
||||
Err error
|
||||
}
|
||||
|
||||
// SwapDone is reported by a swap goroutine when its target is ready (or failed).
|
||||
type SwapDone struct {
|
||||
ModelID string
|
||||
Err error
|
||||
}
|
||||
|
||||
// ServeDoneEvent is reported when a tracked ServeHTTP handler returns.
|
||||
type ServeDoneEvent struct {
|
||||
ModelID string
|
||||
}
|
||||
+153
-44
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -9,19 +10,127 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiUnloadTimeout is used by the API endpoints to stop processes
|
||||
const apiUnloadTimeout = 10 * time.Second
|
||||
|
||||
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
||||
type modelRecord struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Architecture map[string]any `json:"architecture,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||
ContextLength int `json:"context_length,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// cappedMetadataKeys are top-level /v1/models fields produced by the
|
||||
// capabilities renderer. If a model's metadata block defines any of these
|
||||
// keys, the renderer's values win and the metadata keys are dropped.
|
||||
var cappedMetadataKeys = map[string]struct{}{
|
||||
"architecture": {},
|
||||
"capabilities": {},
|
||||
"supported_parameters": {},
|
||||
"context_length": {},
|
||||
}
|
||||
|
||||
// renderCapabilities converts a model's capabilities config into additional
|
||||
// /v1/models fields. Returns zero values when caps.Empty() is true.
|
||||
func renderCapabilities(caps config.ModelCapConfig) (arch map[string]any, capsMap map[string]any, params []string, ctxLen int) {
|
||||
if caps.Empty() {
|
||||
return
|
||||
}
|
||||
|
||||
hasIn := len(caps.In) > 0
|
||||
hasOut := len(caps.Out) > 0
|
||||
|
||||
if hasIn || hasOut {
|
||||
arch = make(map[string]any)
|
||||
}
|
||||
if hasIn {
|
||||
arch["input_modalities"] = caps.In
|
||||
}
|
||||
if hasOut {
|
||||
arch["output_modalities"] = caps.Out
|
||||
}
|
||||
if hasIn && hasOut {
|
||||
arch["modality"] = strings.Join(caps.In, "+") + "->" + strings.Join(caps.Out, "+")
|
||||
}
|
||||
|
||||
// Build capabilities map only if there's something to put in it.
|
||||
if hasIn || hasOut || caps.Tools || caps.Reranker {
|
||||
capsMap = make(map[string]any)
|
||||
}
|
||||
|
||||
if hasIn {
|
||||
if contains(caps.In, "image") {
|
||||
capsMap["vision"] = true
|
||||
}
|
||||
}
|
||||
if hasIn && hasOut {
|
||||
if contains(caps.In, "audio") && contains(caps.Out, "text") {
|
||||
capsMap["audio_transcriptions"] = true
|
||||
}
|
||||
if contains(caps.In, "text") && contains(caps.Out, "audio") {
|
||||
capsMap["audio_speech"] = true
|
||||
}
|
||||
if contains(caps.In, "text") && contains(caps.Out, "image") {
|
||||
capsMap["image_generation"] = true
|
||||
}
|
||||
if contains(caps.In, "image") && contains(caps.Out, "image") {
|
||||
capsMap["image_to_image"] = true
|
||||
}
|
||||
}
|
||||
|
||||
if caps.Tools {
|
||||
capsMap["function_calling"] = true
|
||||
params = []string{"tools", "tool_choice"}
|
||||
}
|
||||
|
||||
if caps.Reranker {
|
||||
capsMap["reranker"] = true
|
||||
}
|
||||
|
||||
if caps.Context > 0 {
|
||||
ctxLen = caps.Context
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// contains reports whether s is present in ss.
|
||||
func contains(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterCappedMetadata returns metadata with renderer-owned keys removed.
|
||||
func filterCappedMetadata(md map[string]any) map[string]any {
|
||||
if len(md) == 0 {
|
||||
return nil
|
||||
}
|
||||
filtered := make(map[string]any, len(md))
|
||||
for k, v := range md {
|
||||
if _, capped := cappedMetadataKeys[k]; !capped {
|
||||
filtered[k] = v
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||
@@ -30,7 +139,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
created := time.Now().Unix()
|
||||
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||
|
||||
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
|
||||
newRecord := func(id, name, description string, metadata map[string]any, caps config.ModelCapConfig) modelRecord {
|
||||
rec := modelRecord{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
@@ -39,6 +148,10 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
Name: strings.TrimSpace(name),
|
||||
Description: strings.TrimSpace(description),
|
||||
}
|
||||
rec.Architecture, rec.Capabilities, rec.SupportedParameters, rec.ContextLength = renderCapabilities(caps)
|
||||
if !caps.Empty() {
|
||||
metadata = filterCappedMetadata(metadata)
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||
}
|
||||
@@ -49,12 +162,12 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
if mc.Unlisted {
|
||||
continue
|
||||
}
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||
|
||||
if s.cfg.IncludeAliasesInList {
|
||||
for _, alias := range mc.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,7 +175,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}, config.ModelCapConfig{}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +207,7 @@ type runningModel struct {
|
||||
// handleUnload stops every running local process. Peer models are remote and
|
||||
// unaffected.
|
||||
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
s.local.Unload(apiUnloadTimeout)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
@@ -160,7 +273,7 @@ func (s *Server) startPreload() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||
|
||||
dw := &discardResponseWriter{status: http.StatusOK}
|
||||
s.local.ServeHTTP(dw, req)
|
||||
@@ -203,9 +316,9 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamPath := r.PathValue("upstreamPath")
|
||||
|
||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
||||
searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,7 +340,29 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the /upstream/<model> prefix before forwarding.
|
||||
r.URL.Path = remainingPath
|
||||
// Pin the resolved model so the router skips body/query extraction.
|
||||
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||
|
||||
// If the path matches an upstream.ignorePaths entry and the model is
|
||||
// not already loaded, refuse the request without triggering a swap. The
|
||||
// server was not able to process the response because the model was not
|
||||
// already loaded.
|
||||
for _, re := range s.cfg.Upstream.IgnorePaths {
|
||||
if !re.MatchString(remainingPath) {
|
||||
continue
|
||||
}
|
||||
if s.local.Handles(modelID) {
|
||||
state, ok := s.local.RunningModels()[modelID]
|
||||
if !ok || state != process.StateReady {
|
||||
shared.SendResponse(w, r, http.StatusConflict,
|
||||
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Either the model is already loaded (no swap would be triggered)
|
||||
// or this is a peer model (peer proxying never swaps). Fall through
|
||||
// to normal dispatch.
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.local.Handles(modelID):
|
||||
@@ -235,32 +370,6 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
case s.peer.Handles(modelID):
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// findModelInPath walks a slash-separated path, building up segments until one
|
||||
// matches a configured model. This resolves model names that contain slashes
|
||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||
// remaining path, and whether a match was found.
|
||||
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
name := ""
|
||||
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = part
|
||||
} else {
|
||||
name = name + "/" + part
|
||||
}
|
||||
|
||||
if modelID, ok := cfg.RealModelName(name); ok {
|
||||
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
+428
-2
@@ -2,11 +2,17 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
func TestServer_HandleListModels(t *testing.T) {
|
||||
@@ -78,6 +84,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) {
|
||||
|
||||
func TestServer_FindModelInPath(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||
"author": {},
|
||||
"author/model": {},
|
||||
"simple": {},
|
||||
}}
|
||||
@@ -91,13 +98,14 @@ func TestServer_FindModelInPath(t *testing.T) {
|
||||
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||
{"/author/model", "author/model", "/", true},
|
||||
{"/author/v1/chat", "author", "/v1/chat", true},
|
||||
{"/missing/v1", "", "", false},
|
||||
{"/", "", "", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
name, _, rem, found := findModelInPath(cfg, c.path)
|
||||
name, _, rem, found := shared.FindModelInPath(cfg, c.path)
|
||||
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||
t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||
}
|
||||
}
|
||||
@@ -133,6 +141,165 @@ func TestServer_HandleUpstream(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func upstreamMetricsServer(response string) *Server {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
proxylog := logmon.NewWriter(io.Discard)
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
muxlog: logmon.NewWriter(io.Discard),
|
||||
proxylog: proxylog,
|
||||
upstreamlog: logmon.NewWriter(io.Discard),
|
||||
inflight: &inflightCounter{},
|
||||
metrics: newMetricsMonitor(proxylog, 10, 0),
|
||||
local: newStubRouter([]string{"m1"}, response),
|
||||
peer: newStubRouter(nil, ""),
|
||||
}
|
||||
s.routes()
|
||||
return s
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream_IgnorePaths(t *testing.T) {
|
||||
// Compile a pattern that matches static asset suffixes.
|
||||
pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`)
|
||||
|
||||
t.Run("matched path, model not loaded, returns 409", func(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
// running is nil/empty: model is not in RunningModels() => not loaded.
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not loaded") {
|
||||
t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("matched path, model already loaded, serves normally", func(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("matched path, peer model, serves normally", func(t *testing.T) {
|
||||
// Peer routers do not appear via RunningModels on the local router;
|
||||
// they should fall through to normal dispatch without 409.
|
||||
local := newStubRouter(nil, "")
|
||||
peer := newStubRouter([]string{"m1"}, "peer-body")
|
||||
s := newTestServer(local, peer)
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "peer-body" {
|
||||
t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
||||
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
||||
s := upstreamMetricsServer(resp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != resp {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
entries := s.metrics.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 metrics entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].Model != "m1" {
|
||||
t.Errorf("model = %q, want m1", entries[0].Model)
|
||||
}
|
||||
if entries[0].ReqPath != "/v1/chat/completions" {
|
||||
t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath)
|
||||
}
|
||||
if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 {
|
||||
t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) {
|
||||
s := upstreamMetricsServer("ok")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "ok" {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if len(s.metrics.getMetrics()) != 0 {
|
||||
t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) {
|
||||
s := upstreamMetricsServer(`{"usage":{}}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if len(s.metrics.getMetrics()) != 0 {
|
||||
t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
@@ -157,3 +324,262 @@ func TestServer_Redirects(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleListModels_Capabilities(t *testing.T) {
|
||||
newServer := func(mc config.ModelConfig) *Server {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m": mc}}
|
||||
return s
|
||||
}
|
||||
getModel := func(t *testing.T, s *Server) modelRecord {
|
||||
t.Helper()
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
|
||||
var resp struct {
|
||||
Data []modelRecord `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(resp.Data))
|
||||
}
|
||||
return resp.Data[0]
|
||||
}
|
||||
|
||||
t.Run("all_fields", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{
|
||||
In: []string{"text", "image"},
|
||||
Out: []string{"text", "audio"},
|
||||
Tools: true,
|
||||
Context: 100000,
|
||||
},
|
||||
}))
|
||||
if m.Architecture == nil {
|
||||
t.Fatal("architecture is nil")
|
||||
}
|
||||
if !anySliceStrEqual(m.Architecture["input_modalities"], []string{"text", "image"}) {
|
||||
t.Errorf("input_modalities = %v", m.Architecture["input_modalities"])
|
||||
}
|
||||
if !anySliceStrEqual(m.Architecture["output_modalities"], []string{"text", "audio"}) {
|
||||
t.Errorf("output_modalities = %v", m.Architecture["output_modalities"])
|
||||
}
|
||||
if m.Architecture["modality"] != "text+image->text+audio" {
|
||||
t.Errorf("modality = %v", m.Architecture["modality"])
|
||||
}
|
||||
if m.Capabilities == nil || m.Capabilities["vision"] != true {
|
||||
t.Errorf("vision = %v", m.Capabilities)
|
||||
}
|
||||
if m.Capabilities["audio_speech"] != true {
|
||||
t.Errorf("audio_speech = %v", m.Capabilities["audio_speech"])
|
||||
}
|
||||
if m.Capabilities["function_calling"] != true {
|
||||
t.Errorf("function_calling = %v", m.Capabilities["function_calling"])
|
||||
}
|
||||
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
|
||||
t.Errorf("supported_parameters = %v", m.SupportedParameters)
|
||||
}
|
||||
if m.ContextLength != 100000 {
|
||||
t.Errorf("context_length = %d", m.ContextLength)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("in_only", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"text", "image"}},
|
||||
}))
|
||||
if m.Architecture == nil {
|
||||
t.Fatal("architecture is nil")
|
||||
}
|
||||
if _, ok := m.Architecture["output_modalities"]; ok {
|
||||
t.Error("should not have output_modalities")
|
||||
}
|
||||
if _, ok := m.Architecture["modality"]; ok {
|
||||
t.Error("should not have modality")
|
||||
}
|
||||
if m.Capabilities == nil || m.Capabilities["vision"] != true {
|
||||
t.Error("expected vision: true")
|
||||
}
|
||||
if m.SupportedParameters != nil {
|
||||
t.Error("should not have supported_parameters")
|
||||
}
|
||||
if m.ContextLength != 0 {
|
||||
t.Error("should not have context_length")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out_only", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Out: []string{"audio"}},
|
||||
}))
|
||||
if m.Architecture == nil {
|
||||
t.Fatal("architecture is nil")
|
||||
}
|
||||
if _, ok := m.Architecture["input_modalities"]; ok {
|
||||
t.Error("should not have input_modalities")
|
||||
}
|
||||
if len(m.Capabilities) > 0 {
|
||||
t.Errorf("expected no capabilities, got %v", m.Capabilities)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tools", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Tools: true},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["function_calling"] != true {
|
||||
t.Error("expected function_calling: true")
|
||||
}
|
||||
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
|
||||
t.Errorf("supported_parameters = %v", m.SupportedParameters)
|
||||
}
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reranker", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Reranker: true},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["reranker"] != true {
|
||||
t.Error("expected reranker: true")
|
||||
}
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Context: 32768},
|
||||
}))
|
||||
if m.ContextLength != 32768 {
|
||||
t.Errorf("context_length = %d", m.ContextLength)
|
||||
}
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("audio_transcriptions", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"audio"}, Out: []string{"text"}},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["audio_transcriptions"] != true {
|
||||
t.Error("expected audio_transcriptions: true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("image_generation", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"text"}, Out: []string{"image"}},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["image_generation"] != true {
|
||||
t.Error("expected image_generation: true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("image_to_image", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"image"}, Out: []string{"image"}},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["image_to_image"] != true {
|
||||
t.Error("expected image_to_image: true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_skip", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{}))
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
if m.Capabilities != nil {
|
||||
t.Error("should not have capabilities")
|
||||
}
|
||||
if m.SupportedParameters != nil {
|
||||
t.Error("should not have supported_parameters")
|
||||
}
|
||||
if m.ContextLength != 0 {
|
||||
t.Error("should not have context_length")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("metadata_precedence", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"text"}},
|
||||
Metadata: map[string]any{
|
||||
"architecture": "should-be-dropped",
|
||||
"custom_field": "should-remain",
|
||||
"capabilities": "also-dropped",
|
||||
"other_metadata": "also-remain",
|
||||
},
|
||||
}))
|
||||
if m.Architecture == nil || m.Architecture["input_modalities"] == nil {
|
||||
t.Fatal("architecture should be rendered, not from metadata")
|
||||
}
|
||||
if m.Meta == nil || m.Meta["llamaswap"] == nil {
|
||||
t.Fatal("meta.llamaswap should exist")
|
||||
}
|
||||
meta := m.Meta["llamaswap"].(map[string]any)
|
||||
if _, ok := meta["architecture"]; ok {
|
||||
t.Error("architecture should be filtered from metadata")
|
||||
}
|
||||
if _, ok := meta["custom_field"]; !ok {
|
||||
t.Error("custom_field should remain in metadata")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("metadata_passthrough_no_caps", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Metadata: map[string]any{
|
||||
"architecture": "preserved",
|
||||
"context_length": 4096,
|
||||
"capabilities": "preserved",
|
||||
"custom_field": "preserved",
|
||||
},
|
||||
}))
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture when caps is empty")
|
||||
}
|
||||
if m.Meta == nil || m.Meta["llamaswap"] == nil {
|
||||
t.Fatal("meta.llamaswap should exist")
|
||||
}
|
||||
meta := m.Meta["llamaswap"].(map[string]any)
|
||||
if _, ok := meta["architecture"]; !ok {
|
||||
t.Error("architecture should be preserved in metadata when caps is empty")
|
||||
}
|
||||
if _, ok := meta["context_length"]; !ok {
|
||||
t.Error("context_length should be preserved in metadata when caps is empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func anySliceStrEqual(v any, want []string) bool {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if len(arr) != len(want) {
|
||||
return false
|
||||
}
|
||||
for i := range arr {
|
||||
if s, ok := arr[i].(string); !ok || s != want[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
+30
-25
@@ -12,19 +12,19 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiModel is one entry in the /api/events modelStatus payload.
|
||||
type apiModel struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
// modelStatus returns every configured model joined with its current process
|
||||
@@ -45,13 +45,15 @@ func (s *Server) modelStatus() []apiModel {
|
||||
if st, ok := running[id]; ok {
|
||||
state = string(st)
|
||||
}
|
||||
_, capsMap, _, _ := renderCapabilities(mc.Capabilities)
|
||||
models = append(models, apiModel{
|
||||
Id: id,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
State: state,
|
||||
Unlisted: mc.Unlisted,
|
||||
Aliases: mc.Aliases,
|
||||
Id: id,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
State: state,
|
||||
Unlisted: mc.Unlisted,
|
||||
Aliases: mc.Aliases,
|
||||
Capabilities: capsMap,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -66,7 +68,7 @@ func (s *Server) modelStatus() []apiModel {
|
||||
|
||||
// handleAPIUnloadAll stops every running local process.
|
||||
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
s.local.Unload(apiUnloadTimeout)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
|
||||
}
|
||||
@@ -76,14 +78,14 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||
realName, found := s.cfg.RealModelName(requested)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
if !s.local.Handles(realName) {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
return
|
||||
}
|
||||
s.local.Unload(0, realName)
|
||||
s.local.Unload(apiUnloadTimeout, realName)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
@@ -92,7 +94,7 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := s.metrics.getMetricsJSON()
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -103,7 +105,9 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(map[string]bool{"enabled": false})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +116,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||
after, err := time.Parse(time.RFC3339, afterStr)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||
return
|
||||
}
|
||||
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||
@@ -134,6 +138,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"enabled": true,
|
||||
"sys_stats": sysStats,
|
||||
"gpu_stats": gpuStats,
|
||||
})
|
||||
@@ -153,19 +158,19 @@ func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||
return
|
||||
}
|
||||
|
||||
capture := s.metrics.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -198,7 +203,7 @@ func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+17
-31
@@ -1,19 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||
// config declares any. It accepts the key via Authorization: Bearer,
|
||||
// Authorization: Basic (password field), or x-api-key. On success the auth
|
||||
// headers are stripped so they never leak to upstream. When no keys are
|
||||
// Authorization: Basic (password field), or x-api-key. When no keys are
|
||||
// configured the middleware is a pass-through.
|
||||
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
keys := cfg.RequiredAPIKeys
|
||||
@@ -22,7 +20,7 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provided := extractAPIKey(r)
|
||||
provided := shared.ExtractAPIKey(r)
|
||||
|
||||
valid := false
|
||||
for _, key := range keys {
|
||||
@@ -33,41 +31,29 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
}
|
||||
if !valid {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
shared.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Del("Authorization")
|
||||
r.Header.Del("x-api-key")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func extractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
// CreateRequestContextMiddleware returns middleware that extracts model and
|
||||
// auth info from the request into the context. Requests where no model can be
|
||||
// identified are rejected with a 404.
|
||||
func CreateRequestContextMiddleware(cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
_ = data
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,48 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := extractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
@@ -74,11 +40,42 @@ func TestServer_IsTokenChar(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RequestContextMiddleware(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"llama3": {},
|
||||
},
|
||||
}
|
||||
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := CreateRequestContextMiddleware(cfg)
|
||||
|
||||
t.Run("known model passes through", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"llama3"}`))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model returns 404", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_AuthMiddleware(t *testing.T) {
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
||||
t.Error("auth headers leaked to upstream")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||
// the model config leaves concurrencyLimit unset. Matches the legacy
|
||||
// proxy.Process default.
|
||||
const defaultConcurrencyLimit = 10
|
||||
|
||||
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
|
||||
// model-dispatched requests per model. Each model gets a semaphore sized to
|
||||
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
|
||||
// immediately acquire a slot is rejected with 429. Models without a local
|
||||
// config entry (e.g. peer-routed models) are not limited.
|
||||
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
||||
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
|
||||
for id, mc := range cfg.Models {
|
||||
limit := defaultConcurrencyLimit
|
||||
if mc.ConcurrencyLimit > 0 {
|
||||
limit = mc.ConcurrencyLimit
|
||||
}
|
||||
semaphores[id] = semaphore.NewWeighted(int64(limit))
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
// fall through for peer models
|
||||
sem, ok := semaphores[data.ModelID]
|
||||
if !ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !sem.TryAcquire(1) {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer sem.Release(1)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
func concurrencyTestReq(model string) *http.Request {
|
||||
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
||||
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"m1": {ConcurrencyLimit: 1},
|
||||
},
|
||||
}
|
||||
|
||||
entered := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var once sync.Once
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
once.Do(func() { close(entered) })
|
||||
<-release
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||
|
||||
// First request occupies the only slot.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
|
||||
}()
|
||||
<-entered
|
||||
|
||||
// Second concurrent request is rejected with 429.
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("over-limit status = %d, want 429", w.Code)
|
||||
}
|
||||
|
||||
// Once the slot frees, a new request succeeds.
|
||||
close(release)
|
||||
<-done
|
||||
w = httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("post-release status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{}}
|
||||
|
||||
called := 0
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
|
||||
if w.Code != http.StatusOK || called != 1 {
|
||||
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -75,6 +77,55 @@ func TestServer_BodyCopier_Flush(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// hijackRecorder is an httptest.ResponseRecorder that also implements
|
||||
// http.Hijacker, returning a pipe so Hijack forwarding can be exercised.
|
||||
type hijackRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (h *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil
|
||||
}
|
||||
|
||||
func TestServer_BodyCopier_Hijack(t *testing.T) {
|
||||
t.Run("forwards to underlying hijacker", func(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
bc := newBodyCopier(&hijackRecorder{httptest.NewRecorder(), server})
|
||||
conn, _, err := bc.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("Hijack: %v", err)
|
||||
}
|
||||
if conn != server {
|
||||
t.Errorf("Hijack returned unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("errors when underlying writer is not a hijacker", func(t *testing.T) {
|
||||
bc := newBodyCopier(httptest.NewRecorder())
|
||||
if _, _, err := bc.Hijack(); err == nil {
|
||||
t.Error("expected error hijacking a non-Hijacker ResponseWriter")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_BodyCopier_SkipsBufferingOnUpgrade(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
bc := newBodyCopier(rec)
|
||||
bc.WriteHeader(http.StatusSwitchingProtocols)
|
||||
bc.Write([]byte("websocket frame bytes"))
|
||||
|
||||
if bc.body.Len() != 0 {
|
||||
t.Errorf("upgrade body buffered = %q, want empty", bc.body.Bytes())
|
||||
}
|
||||
if got := rec.Body.String(); got != "websocket frame bytes" {
|
||||
t.Errorf("client body = %q, want %q", got, "websocket frame bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HeaderMapAndRedact(t *testing.T) {
|
||||
h := http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -34,9 +34,9 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||
return
|
||||
}
|
||||
|
||||
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,13 +97,13 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+16
-5
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||
@@ -75,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
||||
case "upstream":
|
||||
return s.upstreamlog, nil
|
||||
default:
|
||||
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||
if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||
return log, nil
|
||||
}
|
||||
@@ -101,13 +102,13 @@ func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger, err := s.getLogger(logMonitorID)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +151,8 @@ var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"
|
||||
|
||||
// statusRecorder wraps an http.ResponseWriter to capture the response status
|
||||
// code and the number of body bytes written, so the access log can report
|
||||
// them. Flush is forwarded so streaming handlers (SSE) still work.
|
||||
// them. Flush is forwarded so streaming handlers (SSE) still work, and Hijack
|
||||
// is forwarded so httputil.ReverseProxy can upgrade websocket connections.
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
@@ -174,6 +176,15 @@ func (sr *statusRecorder) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack forwards to the underlying ResponseWriter so httputil.ReverseProxy can
|
||||
// take over the connection for websocket upgrades.
|
||||
func (sr *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := sr.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
|
||||
}
|
||||
|
||||
// clientIP resolves the originating client address, preferring proxy headers
|
||||
// over the raw connection address.
|
||||
func clientIP(r *http.Request) string {
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -135,3 +140,103 @@ func TestServer_RequestLogMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_RequestLogMiddleware_WebSocketUpgrade verifies that the access-log
|
||||
// middleware (which wraps responses in statusRecorder) does not break websocket
|
||||
// upgrades proxied through httputil.ReverseProxy. ReverseProxy requires the
|
||||
// ResponseWriter to implement http.Hijacker to take over the connection; if
|
||||
// statusRecorder does not forward Hijack, the upgrade is refused with 502.
|
||||
func TestServer_RequestLogMiddleware_WebSocketUpgrade(t *testing.T) {
|
||||
// Upstream: complete the upgrade handshake then echo bytes back. This
|
||||
// stands in for an upstream that speaks websocket; ReverseProxy only cares
|
||||
// about the 101 response and then copies raw bytes both ways.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Errorf("upstream ResponseWriter is not an http.Hijacker")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Errorf("upstream hijack: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
brw.Flush()
|
||||
// Echo whatever the client sends.
|
||||
buf := make([]byte, 64)
|
||||
n, err := brw.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
brw.Write(buf[:n])
|
||||
brw.Flush()
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
upstreamURL, err := url.Parse(upstream.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse upstream URL: %v", err)
|
||||
}
|
||||
|
||||
// Front server: ReverseProxy wrapped in the access-log middleware, which is
|
||||
// the production statusRecorder-wrapped path.
|
||||
proxy := httputil.NewSingleHostReverseProxy(upstreamURL)
|
||||
mw := CreateRequestLogMiddleware(logmon.NewWriter(io.Discard))
|
||||
front := httptest.NewServer(mw(proxy))
|
||||
defer front.Close()
|
||||
|
||||
frontURL, err := url.Parse(front.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse front URL: %v", err)
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", frontURL.Host, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial front: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
req := "GET / HTTP/1.1\r\n" +
|
||||
"Host: " + frontURL.Host + "\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: websocket\r\n" +
|
||||
"\r\n"
|
||||
if _, err := conn.Write([]byte(req)); err != nil {
|
||||
t.Fatalf("write upgrade request: %v", err)
|
||||
}
|
||||
|
||||
br := bufio.NewReader(conn)
|
||||
statusLine, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "101") {
|
||||
t.Fatalf("websocket upgrade failed: status line = %q, want 101 Switching Protocols", strings.TrimSpace(statusLine))
|
||||
}
|
||||
|
||||
// Drain the rest of the response headers.
|
||||
for {
|
||||
line, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read headers: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Verify bytes flow through the hijacked connection.
|
||||
if _, err := conn.Write([]byte("ping")); err != nil {
|
||||
t.Fatalf("write payload: %v", err)
|
||||
}
|
||||
echo := make([]byte, 4)
|
||||
if _, err := io.ReadFull(br, echo); err != nil {
|
||||
t.Fatalf("read echo: %v", err)
|
||||
}
|
||||
if string(echo) != "ping" {
|
||||
t.Errorf("echo = %q, want %q", echo, "ping")
|
||||
}
|
||||
}
|
||||
|
||||
+149
-33
@@ -1,12 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -23,6 +25,8 @@ import (
|
||||
// TokenMetrics holds token usage and performance metrics.
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
DraftTokens int `json:"draft_tokens"`
|
||||
DraftAccTokens int `json:"draft_acc_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
@@ -31,15 +35,17 @@ type TokenMetrics struct {
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs.
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
ErrorMsg string `json:"error_msg,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ActivityLogEvent carries a single activity log entry to event subscribers.
|
||||
@@ -120,9 +126,11 @@ func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
// record parses a completed response body and stores/emits an activity entry.
|
||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
||||
// requests, with cf controlling which request/response parts are retained.
|
||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
||||
// Successful requests store a zstd+CBOR capture (when enabled) with cf
|
||||
// controlling which parts are retained. Failed (non-200) requests capture the
|
||||
// request only and set ErrorMsg to a description of the failure, so the error
|
||||
// can be inspected without storing unreadable raw response bytes. reqBody and
|
||||
// reqHeaders are the request data buffered before dispatch.
|
||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
@@ -133,6 +141,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
if ctxData, ok := shared.ReadContext(r.Context()); ok && len(ctxData.Metadata) > 0 {
|
||||
tm.Metadata = make(map[string]string, len(ctxData.Metadata))
|
||||
for k, v := range ctxData.Metadata {
|
||||
tm.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
queueAndEmit := func() {
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
@@ -140,7 +155,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||
queueAndEmit()
|
||||
decoded, decErr := mp.decodeResponseBody(recorder, r.URL.Path)
|
||||
tm.ErrorMsg = failedErrorMessage(recorder.Status(), decoded, decErr)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
// Capture the request only; the failure is surfaced via ErrorMsg
|
||||
// rather than storing the (possibly undisplayable) response body.
|
||||
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf&^captureRespBody, reqBody, reqHeaders, nil)
|
||||
mp.emitMetric(tm)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -155,6 +176,7 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
decoded, err := decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
tm.ErrorMsg = fmt.Sprintf("response decompression failed: %v", err)
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
@@ -193,28 +215,99 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
}
|
||||
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
if mp.enableCaptures {
|
||||
capture := ReqRespCapture{
|
||||
ID: tm.ID,
|
||||
ReqPath: r.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
}
|
||||
if cf&captureReqBody != 0 {
|
||||
capture.ReqBody = reqBody
|
||||
}
|
||||
if cf&captureRespHeaders != 0 {
|
||||
capture.RespHeaders = headerMap(recorder.Header())
|
||||
redactHeaders(capture.RespHeaders)
|
||||
delete(capture.RespHeaders, "Content-Encoding")
|
||||
}
|
||||
if cf&captureRespBody != 0 {
|
||||
capture.RespBody = body
|
||||
}
|
||||
if mp.addCapture(capture) {
|
||||
tm.HasCapture = true
|
||||
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf, reqBody, reqHeaders, body)
|
||||
mp.emitMetric(tm)
|
||||
}
|
||||
|
||||
// storeCapture assembles a ReqRespCapture for id, honoring the captureFields
|
||||
// mask, and stores it when captures are enabled. body is the response body to
|
||||
// capture (already decompressed by the caller); pass nil to omit it. Returns
|
||||
// true if a capture was stored.
|
||||
func (mp *metricsMonitor) storeCapture(id int, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string, body []byte) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
capture := ReqRespCapture{
|
||||
ID: id,
|
||||
ReqPath: r.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
}
|
||||
if cf&captureReqBody != 0 {
|
||||
capture.ReqBody = reqBody
|
||||
}
|
||||
if cf&captureRespHeaders != 0 {
|
||||
capture.RespHeaders = headerMap(recorder.Header())
|
||||
redactHeaders(capture.RespHeaders)
|
||||
delete(capture.RespHeaders, "Content-Encoding")
|
||||
}
|
||||
if cf&captureRespBody != 0 {
|
||||
capture.RespBody = body
|
||||
}
|
||||
return mp.addCapture(capture)
|
||||
}
|
||||
|
||||
// decodeResponseBody returns the buffered response body, decompressing it when
|
||||
// the upstream set a Content-Encoding we recognize. On decompression failure it
|
||||
// logs a warning and returns an error so the caller can record a description
|
||||
// (via ErrorMsg) instead of storing unreadable raw bytes.
|
||||
func (mp *metricsMonitor) decodeResponseBody(recorder *responseBodyCopier, path string) ([]byte, error) {
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
encoding := recorder.Header().Get("Content-Encoding")
|
||||
if encoding == "" {
|
||||
return body, nil
|
||||
}
|
||||
decoded, err := decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: response decompression failed: %v, path=%s", err, path)
|
||||
return nil, err
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// errorMessagePaths lists JSON paths where a human-readable error message can
|
||||
// live across OpenAI- and llama.cpp-style error responses.
|
||||
var errorMessagePaths = []string{"error.message", "error", "message", "detail"}
|
||||
|
||||
// extractErrorMessage pulls a human-readable error string from a JSON error
|
||||
// response. Returns "" if no message is found or the body is not valid JSON.
|
||||
func extractErrorMessage(body []byte) string {
|
||||
if !gjson.ValidBytes(body) {
|
||||
return ""
|
||||
}
|
||||
parsed := gjson.ParseBytes(body)
|
||||
for _, path := range errorMessagePaths {
|
||||
v := parsed.Get(path)
|
||||
if v.Exists() && v.Type == gjson.String {
|
||||
if s := strings.TrimSpace(v.String()); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
mp.emitMetric(tm)
|
||||
return ""
|
||||
}
|
||||
|
||||
// failedErrorMessage builds a human-readable description for a non-200 response.
|
||||
// It prefers an error message parsed from the (decompressed) body and falls back
|
||||
// to the HTTP status text. A non-nil decErr indicates the body could not be
|
||||
// decoded, in which case the decode error is described instead.
|
||||
func failedErrorMessage(status int, body []byte, decErr error) string {
|
||||
const maxLen = 500
|
||||
if decErr != nil {
|
||||
return fmt.Sprintf("response decode failed: %v", decErr)
|
||||
}
|
||||
if msg := extractErrorMessage(body); msg != "" {
|
||||
if len(msg) > maxLen {
|
||||
msg = msg[:maxLen] + "..."
|
||||
}
|
||||
return msg
|
||||
}
|
||||
if text := http.StatusText(status); text != "" {
|
||||
return fmt.Sprintf("%d %s", status, text)
|
||||
}
|
||||
return fmt.Sprintf("HTTP %d", status)
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
@@ -335,6 +428,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
draftTokens := -1
|
||||
draftAccTokens := -1
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
@@ -348,6 +443,10 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
if timings.Get("draft_n").Exists() && timings.Get("draft_n_accepted").Exists() {
|
||||
draftTokens = int(timings.Get("draft_n").Int())
|
||||
draftAccTokens = int(timings.Get("draft_n_accepted").Int())
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
@@ -355,6 +454,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
DraftTokens: draftTokens,
|
||||
DraftAccTokens: draftAccTokens,
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
@@ -427,6 +528,12 @@ func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
// On a protocol upgrade (e.g. websocket) the body is raw framed data, not a
|
||||
// metrics-parseable response, so write straight to the client without
|
||||
// buffering a copy we can't use.
|
||||
if w.status == http.StatusSwitchingProtocols {
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
@@ -446,5 +553,14 @@ func (w *responseBodyCopier) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack forwards to the underlying writer so httputil.ReverseProxy can take
|
||||
// over the connection for websocket upgrades.
|
||||
func (w *responseBodyCopier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Status() int { return w.status }
|
||||
func (w *responseBodyCopier) StartTime() time.Time { return w.start }
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||
@@ -21,17 +22,36 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the model-routed endpoint path. Regular routes are
|
||||
// already meterable; /upstream/<model>/<path> is metered only when
|
||||
// the remaining path matches a model-dispatched endpoint.
|
||||
checkPath := r.URL.Path
|
||||
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||
var found bool
|
||||
_, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||
if !found {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isMetricsRecordPath(checkPath) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the model now so downstream dispatch hits the context
|
||||
// fast path; FetchContext restores the request body.
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
// fast path; FetchContext restores the request body for regular
|
||||
// routes and extracts the model from the URL for /upstream routes.
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
// Buffer the request body/headers for capture before dispatch
|
||||
// consumes them.
|
||||
cf := captureFieldsFor(r.URL.Path)
|
||||
cf := captureFieldsFor(checkPath)
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mm.enableCaptures {
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -56,6 +63,199 @@ func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordMetadata(t *testing.T) {
|
||||
mm := newMetricsMonitor(nil, 10, 0)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"usage":{}}`))
|
||||
r = r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{
|
||||
ModelID: "m",
|
||||
Metadata: map[string]string{"client": "web", "trace": "abc"},
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.WriteHeader(http.StatusOK)
|
||||
copier.Write([]byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
|
||||
mm.record("m", r, copier, 0, nil, nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].Metadata["client"] != "web" {
|
||||
t.Errorf("client = %q, want web", entries[0].Metadata["client"])
|
||||
}
|
||||
if entries[0].Metadata["trace"] != "abc" {
|
||||
t.Errorf("trace = %q, want abc", entries[0].Metadata["trace"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordFailedRequestCapture(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
reqHeaders := map[string]string{"content-type": "application/json"}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.Header().Set("Content-Type", "application/json")
|
||||
copier.WriteHeader(http.StatusBadGateway)
|
||||
copier.Write([]byte(`{"error":{"message":"model unavailable"}}`))
|
||||
|
||||
reqBody := []byte(`{"model":"m","messages":[]}`)
|
||||
mm.record("m", r, copier, captureAll, reqBody, reqHeaders)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
entry := entries[0]
|
||||
if entry.RespStatusCode != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", entry.RespStatusCode, http.StatusBadGateway)
|
||||
}
|
||||
if entry.ErrorMsg != "model unavailable" {
|
||||
t.Errorf("error_msg = %q, want extracted message", entry.ErrorMsg)
|
||||
}
|
||||
if !entry.HasCapture {
|
||||
t.Fatal("failed request should capture the request so it can be inspected")
|
||||
}
|
||||
|
||||
got := mm.getCaptureByID(entry.ID)
|
||||
if got == nil {
|
||||
t.Fatal("capture not found")
|
||||
}
|
||||
if string(got.ReqBody) != `{"model":"m","messages":[]}` {
|
||||
t.Errorf("req body = %q", got.ReqBody)
|
||||
}
|
||||
if len(got.RespBody) != 0 {
|
||||
t.Errorf("resp body stored for failed request (len=%d); want none", len(got.RespBody))
|
||||
}
|
||||
if got.RespHeaders["Content-Type"] != "application/json" {
|
||||
t.Errorf("resp Content-Type = %q", got.RespHeaders["Content-Type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordFailedRequestStatusFallback(t *testing.T) {
|
||||
// Non-JSON error body: ErrorMsg falls back to the HTTP status text.
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.WriteHeader(http.StatusBadGateway)
|
||||
copier.Write([]byte("<html>upstream down</html>"))
|
||||
|
||||
mm.record("m", r, copier, captureAll, nil, nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].ErrorMsg != "502 Bad Gateway" {
|
||||
t.Errorf("error_msg = %q, want status text", entries[0].ErrorMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordFailedRequestCaptureDisabled(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 0) // captures disabled
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.WriteHeader(http.StatusInternalServerError)
|
||||
copier.Write([]byte(`{"error":"boom"}`))
|
||||
|
||||
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].HasCapture {
|
||||
t.Fatal("captures disabled, HasCapture should be false")
|
||||
}
|
||||
// ErrorMsg is independent of whether captures are enabled.
|
||||
if entries[0].ErrorMsg != "boom" {
|
||||
t.Errorf("error_msg = %q, want boom", entries[0].ErrorMsg)
|
||||
}
|
||||
if mm.getCaptureByID(entries[0].ID) != nil {
|
||||
t.Fatal("no capture should be stored when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordDecompressionFailureSetsErrorMsg(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.Header().Set("Content-Encoding", "gzip")
|
||||
copier.WriteHeader(http.StatusOK)
|
||||
copier.Write([]byte("not-really-gzip"))
|
||||
|
||||
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].ErrorMsg == "" {
|
||||
t.Fatal("expected ErrorMsg for decompression failure")
|
||||
}
|
||||
// Raw bytes must not be stored when the body could not be decoded.
|
||||
if entries[0].HasCapture {
|
||||
t.Fatal("decompression failure should not store a capture")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_DecodeResponseBody(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
|
||||
// No Content-Encoding: body returned unchanged.
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.Write([]byte("plain"))
|
||||
got, err := mm.decodeResponseBody(copier, "/p")
|
||||
if err != nil || string(got) != "plain" {
|
||||
t.Fatalf("plain body = %q, err = %v", got, err)
|
||||
}
|
||||
|
||||
// Bogus gzip payload: returns an error and no body (no raw bytes kept).
|
||||
w2 := httptest.NewRecorder()
|
||||
copier2 := newBodyCopier(w2)
|
||||
copier2.Header().Set("Content-Encoding", "gzip")
|
||||
copier2.Write([]byte("not-really-gzip"))
|
||||
got, err = mm.decodeResponseBody(copier2, "/p")
|
||||
if err == nil {
|
||||
t.Fatal("expected decompression error")
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("expected nil body on failure, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ExtractErrorMessage(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{"openai object", `{"error":{"message":"rate limited"}}`, "rate limited"},
|
||||
{"string error", `{"error":"bad request"}`, "bad request"},
|
||||
{"message field", `{"message":"nope"}`, "nope"},
|
||||
{"detail field", `{"detail":"oops"}`, "oops"},
|
||||
{"object error ignored", `{"error":{"code":42}}`, ""},
|
||||
{"no error", `{"usage":{}}`, ""},
|
||||
{"invalid json", `not-json`, ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractErrorMessage([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractErrorMessage = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||
// /infill responses are arrays; timings live in the last element.
|
||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||
@@ -72,3 +272,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that
|
||||
// an /upstream/<model>/v1/audio/speech request uses the path-specific capture
|
||||
// mask (headers only) rather than falling back to captureAll.
|
||||
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
|
||||
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("BINARY-AUDIO-DATA"))
|
||||
})
|
||||
handler := CreateMetricsMiddleware(mm, cfg)(inner)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
|
||||
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) == 0 {
|
||||
t.Fatal("no metrics recorded")
|
||||
}
|
||||
last := entries[len(entries)-1]
|
||||
if !last.HasCapture {
|
||||
t.Fatal("expected capture to be stored")
|
||||
}
|
||||
cap := mm.getCaptureByID(last.ID)
|
||||
if cap == nil {
|
||||
t.Fatal("capture not found")
|
||||
}
|
||||
if len(cap.RespBody) != 0 {
|
||||
t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", len(cap.RespBody))
|
||||
}
|
||||
if len(cap.RespHeaders) == 0 {
|
||||
t.Error("RespHeaders not stored; want captureRespHeaders mask")
|
||||
}
|
||||
}
|
||||
|
||||
+40
-23
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||
@@ -88,6 +89,27 @@ var modelGetRoutes = []string{
|
||||
"/sdapi/v1/loras",
|
||||
}
|
||||
|
||||
// isMetricsRecordPath reports whether path is one of the model-dispatched
|
||||
// endpoints that the metrics middleware records in the activity log.
|
||||
func isMetricsRecordPath(path string) bool {
|
||||
for _, p := range modelPostJSONRoutes {
|
||||
if p == path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, p := range modelPostFormRoutes {
|
||||
if p == path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, p := range modelGetRoutes {
|
||||
if p == path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BuildInfo carries version metadata surfaced by GET /api/version.
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
@@ -99,12 +121,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
||||
var local router.LocalRouter
|
||||
var err error
|
||||
|
||||
if cfg.Matrix != nil {
|
||||
switch cfg.Routing.Router.Use {
|
||||
case "matrix":
|
||||
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating matrix router: %w", err)
|
||||
}
|
||||
} else {
|
||||
default: // "group"
|
||||
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating group router: %w", err)
|
||||
@@ -137,13 +160,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
||||
}
|
||||
|
||||
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||
// router. The model is resolved once via router.FetchContext.
|
||||
// router. The model is resolved once via shared.FetchContext.
|
||||
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
stripVersionPrefix(r)
|
||||
|
||||
data, err := router.FetchContext(r, s.cfg)
|
||||
data, err := shared.FetchContext(r, s.cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -155,7 +178,7 @@ func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendError(w, r, router.ErrNoRouterFound)
|
||||
shared.SendError(w, r, router.ErrNoRouterFound)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,21 +193,13 @@ func stripVersionPrefix(r *http.Request) {
|
||||
// routes builds the mux, registers every route, and wraps the mux with the
|
||||
// global CORS middleware.
|
||||
func (s *Server) routes() {
|
||||
authMW := CreateAuthMiddleware(s.cfg)
|
||||
filterMW := CreateFilterMiddleware(s.cfg)
|
||||
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
||||
|
||||
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
||||
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
||||
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
||||
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
||||
// other's Content-Type. Both run before the metrics middleware so it buffers
|
||||
// the rewritten body.
|
||||
authMW := CreateAuthMiddleware(s.cfg)
|
||||
modelChain := chain.New(
|
||||
authMW,
|
||||
CreateConcurrencyMiddleware(s.cfg),
|
||||
filterMW,
|
||||
formFilterMW,
|
||||
CreateRequestContextMiddleware(s.cfg),
|
||||
CreateFilterMiddleware(s.cfg),
|
||||
CreateFormFilterMiddleware(s.cfg),
|
||||
CreateInflightMiddleware(s.inflight),
|
||||
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||
)
|
||||
@@ -215,19 +230,21 @@ func (s *Server) routes() {
|
||||
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||
|
||||
// Embedded UI.
|
||||
mux.HandleFunc("GET /ui/", s.handleUI)
|
||||
mux.Handle("GET /ui/", chain.New(authMW).ThenFunc(s.handleUI))
|
||||
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||
|
||||
// Prometheus metrics (no auth, matches the legacy endpoint).
|
||||
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
||||
// Prometheus metrics (wrapped by apiChain, matches the legacy endpoint).
|
||||
mux.Handle("GET /metrics", apiChain.ThenFunc(s.handleMetrics))
|
||||
|
||||
// Operations endpoints.
|
||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||
|
||||
// Upstream passthrough.
|
||||
// Upstream passthrough. Meter only the model-dispatched endpoints that can
|
||||
// produce token usage/timings.
|
||||
upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg))
|
||||
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
||||
mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream))
|
||||
|
||||
// API group (API-key protected) consumed by the UI.
|
||||
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||
|
||||
@@ -84,10 +84,15 @@ func chatRequest(model string) *http.Request {
|
||||
|
||||
func TestServer_New_GroupConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
|
||||
cfg := config.Config{HealthCheckTimeout: 15}
|
||||
cfg.Routing.Router.Use = "group"
|
||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (group): %v", err)
|
||||
}
|
||||
if _, ok := s.local.(*router.Group); !ok {
|
||||
t.Fatalf("localRouter=%T want *router.Group", s.local)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
@@ -95,11 +100,16 @@ func TestServer_New_GroupConfig(t *testing.T) {
|
||||
|
||||
func TestServer_New_MatrixConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
|
||||
cfg := config.Config{HealthCheckTimeout: 15}
|
||||
cfg.Routing.Router.Use = "matrix"
|
||||
cfg.Routing.Router.Settings.Matrix = &config.MatrixConfig{}
|
||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (matrix): %v", err)
|
||||
}
|
||||
if _, ok := s.local.(*router.Matrix); !ok {
|
||||
t.Fatalf("localRouter=%T want *router.Matrix", s.local)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,298 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type contextkey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type ReqContextData struct {
|
||||
ApiKey string
|
||||
Model string
|
||||
ModelID string
|
||||
Streaming bool
|
||||
SendLoadingState bool
|
||||
// Metadata is a request-scoped key/value bag that handlers may mutate
|
||||
// while processing. The metrics middleware copies it into ActivityLogEntry.
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
var (
|
||||
ReqContextKey = &contextkey{"context"}
|
||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||
)
|
||||
|
||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
for k, v := range httpErr.Header() {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
w.WriteHeader(httpErr.StatusCode())
|
||||
w.Write(httpErr.Body())
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, ErrNoModelInContext):
|
||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||
case errors.Is(err, ErrNoPeerModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||
case errors.Is(err, ErrNoLocalModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
case errors.Is(err, ErrNoRouterFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||
default:
|
||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/plain") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(acceptHeader, "text/html") {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message})
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`))
|
||||
return
|
||||
}
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context, then
|
||||
// from an /upstream/<model> path prefix, then from the request body/query.
|
||||
// If it extracts the model it will store it in the context for downstream
|
||||
// handlers. An error will be returned when a model cannot be identified.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||
if data, ok := extractUpstreamContext(r, cfg); ok {
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
realName = data.Model
|
||||
}
|
||||
data.ModelID = realName
|
||||
if mc, ok := cfg.Models[realName]; ok {
|
||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
|
||||
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
|
||||
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||
if !found {
|
||||
return ReqContextData{}, false
|
||||
}
|
||||
return ReqContextData{
|
||||
Model: searchName,
|
||||
ModelID: realName,
|
||||
ApiKey: ExtractAPIKey(r),
|
||||
Streaming: r.URL.Query().Get("stream") == "true",
|
||||
SendLoadingState: sendLoadingState(cfg, realName),
|
||||
Metadata: make(map[string]string),
|
||||
}, true
|
||||
}
|
||||
|
||||
// sendLoadingState reports whether the configured model wants loading-state SSEs.
|
||||
func sendLoadingState(cfg config.Config, modelID string) bool {
|
||||
if mc, ok := cfg.Models[modelID]; ok {
|
||||
return mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FindModelInPath walks a slash-separated path, building up segments until one
|
||||
// matches a configured model. This resolves model names that contain slashes
|
||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||
// remaining path, and whether a match was found.
|
||||
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
name := ""
|
||||
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = part
|
||||
} else {
|
||||
name = name + "/" + part
|
||||
}
|
||||
|
||||
if modelID, ok := cfg.RealModelName(name); ok {
|
||||
searchName = name
|
||||
realName = modelID
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ReqContextKey, data)
|
||||
}
|
||||
|
||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// SetReqData attaches a key/value pair to the request context's metadata map.
|
||||
// The metadata map must already exist in the context's ReqContextData; callers
|
||||
// should ensure FetchContext has run or initialize the map themselves.
|
||||
// It returns an error for nil contexts or contexts without request data.
|
||||
func SetReqData(ctx context.Context, key, value string) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("cannot set request metadata on nil context")
|
||||
}
|
||||
data, ok := ReadContext(ctx)
|
||||
if !ok {
|
||||
return fmt.Errorf("no request context data found")
|
||||
}
|
||||
if data.Metadata == nil {
|
||||
return fmt.Errorf("no metadata map in request context")
|
||||
}
|
||||
data.Metadata[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractContext pulls fields from an HTTP request into a ReqContextData,
|
||||
// returning whatever is available. For GET requests it reads query parameters.
|
||||
// For POST requests it inspects Content-Type and parses JSON,
|
||||
// multipart/form-data, or application/x-www-form-urlencoded bodies. The
|
||||
// request body is always restored before returning. An error is returned only
|
||||
// for I/O or parse failures, not for missing fields.
|
||||
func extractContext(r *http.Request) (ReqContextData, error) {
|
||||
|
||||
apiKey := ExtractAPIKey(r)
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
q := r.URL.Query()
|
||||
return ReqContextData{
|
||||
Model: q.Get("model"),
|
||||
Streaming: q.Get("stream") == "true",
|
||||
ApiKey: apiKey,
|
||||
Metadata: make(map[string]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}()
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return ReqContextData{
|
||||
Model: gjson.GetBytes(bodyBytes, "model").String(),
|
||||
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
ApiKey: apiKey,
|
||||
Metadata: make(map[string]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||
// buffered bytes. The deferred restore above will reset r.Body again
|
||||
// after parsing.
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return ReqContextData{
|
||||
Model: r.FormValue("model"),
|
||||
Streaming: r.FormValue("stream") == "true",
|
||||
ApiKey: apiKey,
|
||||
Metadata: make(map[string]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func ExtractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
scheme, credentials, ok := strings.Cut(auth, " ")
|
||||
if ok {
|
||||
switch strings.ToLower(scheme) {
|
||||
case "bearer":
|
||||
bearerKey = credentials
|
||||
case "basic":
|
||||
if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,525 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestExtractContext_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "model=llama3", "llama3", false},
|
||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||
{"model empty string", `{"model":""}`, "", false},
|
||||
{"model key missing", `{"stream":true}`, "", false},
|
||||
{"invalid json", `not-json`, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
form := url.Values{}
|
||||
if tt.formModel != "" {
|
||||
form.Set("model", tt.formModel)
|
||||
}
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
if tt.formModel != "" {
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte(tt.formModel))
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||
body := `{"model":"llama3","stream":true}`
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte("whisper-1"))
|
||||
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
||||
ff.Write([]byte("fake-audio-bytes"))
|
||||
mw.Close()
|
||||
|
||||
original := buf.Bytes()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if !bytes.Equal(remaining, original) {
|
||||
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||
body := "model=whisper-1&extra=value"
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
if !ok {
|
||||
t.Fatalf("ContextKey not set or wrong type")
|
||||
}
|
||||
if data.Model != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_WithAlias(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||
data, _ := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
if data.Model != "llama" {
|
||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||
parent := context.Background()
|
||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
if v := parent.Value(ReqContextKey); v != nil {
|
||||
t.Errorf("parent context was mutated: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
wantReq string
|
||||
wantReal string
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "model present, same name",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
||||
wantReq: "llama3",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model present, aliased",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
||||
wantReq: "llama",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model absent",
|
||||
ctx: context.Background(),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "model is empty string",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotData, ok := ReadContext(tt.ctx)
|
||||
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
||||
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", "model=llama3&stream=true", true},
|
||||
{"streaming false", "model=llama3&stream=false", false},
|
||||
{"no stream param", "model=llama3", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||
{"no stream param", `{"model":"llama3"}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if !got.Streaming {
|
||||
t.Error("Streaming should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_ApiKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
ct string
|
||||
body string
|
||||
auth string
|
||||
xapi string
|
||||
wantKey string
|
||||
}{
|
||||
{"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"},
|
||||
{"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"},
|
||||
{"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"},
|
||||
{"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"},
|
||||
{"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"},
|
||||
{"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"},
|
||||
{"no key", http.MethodGet, "", "", "", "", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var body io.Reader
|
||||
if c.body != "" {
|
||||
body = strings.NewReader(c.body)
|
||||
}
|
||||
r, _ := http.NewRequest(c.method, "/", body)
|
||||
if c.ct != "" {
|
||||
r.Header.Set("Content-Type", c.ct)
|
||||
}
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("extractContext: %v", err)
|
||||
}
|
||||
if got.ApiKey != c.wantKey {
|
||||
t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetReqData(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3", Metadata: make(map[string]string)})
|
||||
|
||||
if err := SetReqData(ctx, "client", "web"); err != nil {
|
||||
t.Fatalf("SetReqData: %v", err)
|
||||
}
|
||||
if err := SetReqData(ctx, "trace", "abc123"); err != nil {
|
||||
t.Fatalf("SetReqData: %v", err)
|
||||
}
|
||||
|
||||
data, ok := ReadContext(ctx)
|
||||
if !ok {
|
||||
t.Fatal("context data missing")
|
||||
}
|
||||
if data.Metadata["client"] != "web" {
|
||||
t.Errorf("client = %q, want %q", data.Metadata["client"], "web")
|
||||
}
|
||||
if data.Metadata["trace"] != "abc123" {
|
||||
t.Errorf("trace = %q, want %q", data.Metadata["trace"], "abc123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetReqData_Errors(t *testing.T) {
|
||||
if err := SetReqData(context.Background(), "k", "v"); err == nil {
|
||||
t.Error("expected error when no request context data exists")
|
||||
}
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
if err := SetReqData(ctx, "k", "v"); err == nil {
|
||||
t.Error("expected error when metadata map is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
{"lowercase bearer", "bearer tok123", "", "tok123"},
|
||||
{"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"},
|
||||
{"mixed case BEARER", "BEARER tok456", "", "tok456"},
|
||||
{"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := ExtractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchContext_UpstreamPath(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"m1": {},
|
||||
"author/model": {},
|
||||
"real": {Aliases: []string{"nick"}},
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantModel string
|
||||
wantModelID string
|
||||
wantErr bool
|
||||
}{
|
||||
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
|
||||
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
|
||||
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
|
||||
{"bare model path", "/upstream/m1/", "m1", "m1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
|
||||
data, err := FetchContext(r, cfg)
|
||||
if (err != nil) != c.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
|
||||
}
|
||||
if c.wantErr {
|
||||
return
|
||||
}
|
||||
if data.Model != c.wantModel {
|
||||
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
|
||||
}
|
||||
if data.ModelID != c.wantModelID {
|
||||
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
|
||||
}
|
||||
if data.Metadata == nil {
|
||||
t.Error("metadata map not initialized")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
body := `{"model":"should-not-matter"}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
|
||||
|
||||
_, err := FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchContext: %v", err)
|
||||
}
|
||||
|
||||
// The body should be untouched so the upstream handler can still read it.
|
||||
got, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
if string(got) != body {
|
||||
t.Errorf("body was consumed: %q", string(got))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// HTTPError is an error that carries a complete HTTP response. A producer (e.g.
|
||||
// a scheduler shedding a request) returns one of these; a renderer (e.g.
|
||||
// router.SendError) writes the status, headers, and body verbatim instead of
|
||||
// mapping the error to a generic status. It is the seam that lets a component
|
||||
// shed a request with a rich response (e.g. a 429 with rate-limit headers and a
|
||||
// JSON hint body) without the renderer knowing the producer's internals.
|
||||
type HTTPError interface {
|
||||
error
|
||||
StatusCode() int
|
||||
Header() http.Header
|
||||
Body() []byte
|
||||
}
|
||||
|
||||
// ConcurrencyLimitError is an HTTPError for a 429 concurrency-limit rejection.
|
||||
// Zero-value fields fall back to sensible defaults: a 1-second Retry-After and a
|
||||
// JSON hint body.
|
||||
type ConcurrencyLimitError struct {
|
||||
// RetryAfter, when > 0, is sent as the Retry-After header (in seconds).
|
||||
// Defaults to 1.
|
||||
RetryAfter int
|
||||
|
||||
// Message overrides the JSON body's "error" field. Defaults to
|
||||
// "Too many requests".
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e ConcurrencyLimitError) Error() string { return "concurrency limit reached" }
|
||||
|
||||
func (e ConcurrencyLimitError) StatusCode() int { return http.StatusTooManyRequests }
|
||||
|
||||
func (e ConcurrencyLimitError) Header() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Content-Type", "application/json")
|
||||
h.Set("Retry-After", e.retryAfter())
|
||||
return h
|
||||
}
|
||||
|
||||
func (e ConcurrencyLimitError) Body() []byte {
|
||||
b, _ := json.Marshal(map[string]string{"error": e.message()})
|
||||
return b
|
||||
}
|
||||
|
||||
func (e ConcurrencyLimitError) retryAfter() string {
|
||||
if e.RetryAfter > 0 {
|
||||
return strconv.Itoa(e.RetryAfter)
|
||||
}
|
||||
return "1"
|
||||
}
|
||||
|
||||
func (e ConcurrencyLimitError) message() string {
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
return "Too many requests"
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package shared
|
||||
|
||||
import "net"
|
||||
|
||||
// IsLoopbackAddr reports whether listenAddr binds exclusively to loopback.
|
||||
// Addresses with an empty or wildcard host (e.g. ":8080", "0.0.0.0:8080",
|
||||
// "[::]:8080") bind on all interfaces and return false.
|
||||
func IsLoopbackAddr(listenAddr string) bool {
|
||||
host, _, err := net.SplitHostPort(listenAddr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
// hostname case (e.g. "localhost")
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if !net.ParseIP(a).IsLoopback() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(addrs) > 0
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files.
|
||||
// It fires OnChange when a file is added, removed, or has its mod time/size
|
||||
// change. Like Watcher it is poll-based so it works in Docker bind-mounts and
|
||||
// k8s ConfigMap projections where inotify is unreliable.
|
||||
//
|
||||
// The baseline poll establishes initial state and does not fire OnChange.
|
||||
type DirWatcher struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
OnChange func()
|
||||
}
|
||||
|
||||
// dirSnapshot is an ordered map of file name -> file state. The ordering is
|
||||
// derived from sorted filenames so two snapshots compare deterministically
|
||||
// regardless of readdir order. exists reflects whether the directory was
|
||||
// readable at scan time; a missing directory yields exists=false.
|
||||
type dirSnapshot struct {
|
||||
exists bool
|
||||
names []string
|
||||
states map[string]snapshot
|
||||
}
|
||||
|
||||
func newDirSnapshot() dirSnapshot {
|
||||
return dirSnapshot{states: make(map[string]snapshot)}
|
||||
}
|
||||
|
||||
// equal reports whether two snapshots describe the same file set and per-file
|
||||
// state. A missing directory (exists=false) is treated as equal to any other
|
||||
// missing directory regardless of cached names.
|
||||
func (s dirSnapshot) equal(other dirSnapshot) bool {
|
||||
if !s.exists && !other.exists {
|
||||
return true
|
||||
}
|
||||
if s.exists != other.exists {
|
||||
return false
|
||||
}
|
||||
if len(s.names) != len(other.names) {
|
||||
return false
|
||||
}
|
||||
for i, n := range s.names {
|
||||
if other.names[i] != n {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, n := range s.names {
|
||||
a, b := s.states[n], other.states[n]
|
||||
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||
// OnChange whenever the directory's YAML file set changes.
|
||||
//
|
||||
// Policy mirrors the single-file Watcher: disappearance (directory missing or
|
||||
// empty) is treated as a transient rename-style write and stays quiet; the
|
||||
// transition back to present-with-content fires OnChange.
|
||||
func (w *DirWatcher) Run(ctx context.Context) {
|
||||
interval := w.Interval
|
||||
if interval <= 0 {
|
||||
interval = DefaultInterval
|
||||
}
|
||||
|
||||
prev := scanDir(w.Path)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cur := scanDir(w.Path)
|
||||
// Suppress transitions involving an empty or missing directory —
|
||||
// these are treated as transient rename-style writes, mirroring
|
||||
// the single-file Watcher. Only present-with-content →
|
||||
// present-with-content (changed) or no-content →
|
||||
// present-with-content fires OnChange.
|
||||
prevHasContent := prev.exists && len(prev.names) > 0
|
||||
curHasContent := cur.exists && len(cur.names) > 0
|
||||
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
|
||||
w.OnChange()
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
|
||||
// directory cannot be read (missing, permission denied) the snapshot reports
|
||||
// exists=false; the next successful scan will detect the recovery and fire
|
||||
// OnChange.
|
||||
func scanDir(dir string) dirSnapshot {
|
||||
snap := newDirSnapshot()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return snap // exists=false
|
||||
}
|
||||
snap.exists = true
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
fi, err := os.Stat(filepath.Join(dir, name))
|
||||
if err != nil {
|
||||
// File disappeared between ReadDir and Stat; skip it — the
|
||||
// next poll will observe the removal cleanly.
|
||||
continue
|
||||
}
|
||||
snap.names = append(snap.names, name)
|
||||
snap.states[name] = snapshot{
|
||||
exists: true,
|
||||
modTime: fi.ModTime(),
|
||||
size: fi.Size(),
|
||||
}
|
||||
}
|
||||
sort.Strings(snap.names)
|
||||
return snap
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// startDirWatcher launches w.Run in a goroutine and returns a function that
|
||||
// cancels the context and waits for Run to return.
|
||||
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("DirWatcher did not stop within 2s of cancel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeYAMLInDir(t *testing.T, dir, name, content string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
|
||||
}
|
||||
|
||||
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
time.Sleep(testInterval * 5)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||
}
|
||||
|
||||
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Adding a .txt file must not fire.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
|
||||
time.Sleep(testInterval * 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
|
||||
|
||||
// Adding a .yml file must fire.
|
||||
writeYAMLInDir(t, dir, "b.yml", "b")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
|
||||
}
|
||||
|
||||
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Remove the directory. No fire expected on disappearance alone.
|
||||
require.NoError(t, os.RemoveAll(dir))
|
||||
time.Sleep(testInterval * 3)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
|
||||
|
||||
// Recreate the directory and a YAML file; the recovery should fire.
|
||||
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||
writeYAMLInDir(t, dir, "recovered.yaml", "r")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
|
||||
}
|
||||
|
||||
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
|
||||
// Present-with-content → empty (all YAML removed, dir still exists)
|
||||
// must stay quiet — treated as transient per the documented policy.
|
||||
// The transition back to content fires.
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Remove the only YAML file. Dir still exists but is empty of YAML.
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
|
||||
time.Sleep(testInterval * 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
|
||||
|
||||
// Add a YAML file back; transition to present-with-content fires.
|
||||
writeYAMLInDir(t, dir, "c.yaml", "c")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
|
||||
}
|
||||
|
||||
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
w := &DirWatcher{Path: dir, Interval: testInterval}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() { w.Run(ctx); close(done) }()
|
||||
|
||||
time.Sleep(testInterval * 2)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not return within 2s of cancel")
|
||||
}
|
||||
}
|
||||
+73
-21
@@ -6,6 +6,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/server"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/mostlygeek/llama-swap/internal/watcher"
|
||||
@@ -53,7 +55,8 @@ var logTimeFormats = map[string]string{
|
||||
}
|
||||
|
||||
func main() {
|
||||
flagConfig := flag.String("config", "", "path to config file (required)")
|
||||
flagConfig := flag.String("config", "", "path to config file")
|
||||
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
|
||||
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
||||
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
@@ -66,8 +69,8 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if *flagConfig == "" {
|
||||
slog.Error("-config is required")
|
||||
if *flagConfig == "" && *flagConfigDir == "" {
|
||||
slog.Error("at least one of -config or -config-dir must be provided")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -86,10 +89,9 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
configPath := *flagConfig
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||
if err != nil {
|
||||
slog.Error("failed to load config", "path", configPath, "error", err)
|
||||
slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -122,6 +124,13 @@ func main() {
|
||||
applyLogSettings(cfg)
|
||||
proxyLog.Debugf("PID: %d", os.Getpid())
|
||||
|
||||
// On Windows, bind the process tree to a Job Object so every upstream
|
||||
// process is reaped when llama-swap exits — even on a forced kill. No-op
|
||||
// elsewhere. Non-fatal: a failure just falls back to per-process teardown.
|
||||
if err := process.SetupTreeCleanup(); err != nil {
|
||||
proxyLog.Warnf("failed to set up process tree cleanup: %v", err)
|
||||
}
|
||||
|
||||
// perfMon outlives config reloads; its config is updated in place.
|
||||
var perfMon *perf.Monitor
|
||||
if !cfg.Performance.Disabled {
|
||||
@@ -178,7 +187,7 @@ func main() {
|
||||
|
||||
proxyLog.Info("reloading configuration")
|
||||
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||
if err != nil {
|
||||
proxyLog.Warnf("failed to reload config: %v", err)
|
||||
return
|
||||
@@ -221,19 +230,37 @@ func main() {
|
||||
defer watcherCancel()
|
||||
|
||||
if *flagWatchConfig {
|
||||
absConfigPath, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||
go func() {
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
|
||||
if *flagConfig != "" {
|
||||
absConfigPath, err := filepath.Abs(*flagConfig)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
go func() {
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
if *flagConfigDir != "" {
|
||||
absConfigDir, err := filepath.Abs(*flagConfigDir)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
go func() {
|
||||
(&configwatcher.DirWatcher{
|
||||
Path: absConfigDir,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
@@ -254,6 +281,11 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
if !shared.IsLoopbackAddr(listenAddr) {
|
||||
_, port, _ := net.SplitHostPort(listenAddr)
|
||||
proxyLog.Infof("llama-swap is reachable by all hosts on the network, use -listen localhost:%s to restrict to loopback only", port)
|
||||
}
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
@@ -267,6 +299,16 @@ func main() {
|
||||
proxyLog.Infof("received signal %v, shutting down", sig)
|
||||
watcherCancel()
|
||||
|
||||
// Backstop against a stalled shutdown: force the process to
|
||||
// exit once the whole graceful sequence has had its full budget.
|
||||
// On Windows the Job Object reaps upstream processes on exit, so
|
||||
// a forced exit still cleans up rather than orphaning children.
|
||||
go func() {
|
||||
time.Sleep(shutdownTimeout + 5*time.Second)
|
||||
proxyLog.Warnf("graceful shutdown exceeded %v, forcing exit", shutdownTimeout)
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
activeMu.RLock()
|
||||
srv := activeSrv
|
||||
activeMu.RUnlock()
|
||||
@@ -275,13 +317,23 @@ func main() {
|
||||
// drain without blocking on them for the full timeout.
|
||||
srv.CloseStreams()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
// Both phases share a single deadline so total shutdown is
|
||||
// bounded by shutdownTimeout rather than 2x it.
|
||||
deadline := time.Now().Add(shutdownTimeout)
|
||||
shutdownCtx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
proxyLog.Warnf("http server shutdown error: %v", err)
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(shutdownTimeout); err != nil {
|
||||
// Clamp the remaining budget to a small positive value: a
|
||||
// non-positive timeout makes the router fall back to its own
|
||||
// healthCheckTimeout, which would defeat the shared deadline.
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
remaining = time.Millisecond
|
||||
}
|
||||
if err := srv.Shutdown(remaining); err != nil {
|
||||
proxyLog.Warnf("router shutdown error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
ui_dist/*
|
||||
@@ -1,27 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -1,60 +0,0 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const ActivityLogEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = logmon.NewWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
func TestMain(m *testing.M) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(logmon.LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(logmon.LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
// Helper function to get the binary path
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
|
||||
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
|
||||
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||
t.Helper()
|
||||
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||
require.NoError(t, err)
|
||||
return cfg
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
|
||||
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
|
||||
// model IDs to their respond strings; if a model ID is not in the map, the model
|
||||
// ID itself is used.
|
||||
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||
for _, pg := range pm.processGroups {
|
||||
for modelID, process := range pg.processes {
|
||||
respond := modelID
|
||||
if r, ok := modelResponses[modelID]; ok {
|
||||
respond = r
|
||||
}
|
||||
process.testHandler = newTestHandler(respond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||
// It supports the endpoints that routing tests depend on, without launching
|
||||
// any subprocess or binding any port.
|
||||
func respondJSON(w http.ResponseWriter, respond string, bodyBytes []byte) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"h_content_length": strconv.Itoa(len(bodyBytes)),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newTestHandler(respond string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
isStreaming := r.URL.Query().Get("stream") == "true"
|
||||
|
||||
if wait := r.URL.Query().Get("wait"); wait != "" {
|
||||
if d, err := time.ParseDuration(wait); err == nil {
|
||||
time.Sleep(d)
|
||||
}
|
||||
}
|
||||
|
||||
if isStreaming {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
flusher := w.(http.Flusher)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
finalData, _ := json.Marshal(map[string]any{
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
|
||||
flusher.Flush()
|
||||
|
||||
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
} else {
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != respond {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
|
||||
for _, path := range []string{
|
||||
"/chat/completions", "/completions",
|
||||
"/responses", "/messages", "/messages/count_tokens",
|
||||
"/embeddings", "/rerank", "/reranking",
|
||||
} {
|
||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
}
|
||||
|
||||
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
model := r.FormValue("model")
|
||||
if model == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
fileBytes, _ := io.ReadAll(file)
|
||||
file.Close()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
|
||||
"model": model,
|
||||
"h_content_type": r.Header.Get("Content-Type"),
|
||||
"h_content_length": r.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
|
||||
model := r.URL.Query().Get("model")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"voices": []string{"voice1"}, "model": model,
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprint(w, respond)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
-330
@@ -1,330 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||
// It is safe for concurrent reads after construction.
|
||||
type MatrixSolver struct {
|
||||
expandedSets []config.ExpandedSet // all valid model combinations
|
||||
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||
}
|
||||
|
||||
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
|
||||
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
|
||||
modelToSets := make(map[string][]int)
|
||||
for i, es := range expandedSets {
|
||||
for _, model := range es.Models {
|
||||
modelToSets[model] = append(modelToSets[model], i)
|
||||
}
|
||||
}
|
||||
|
||||
return &MatrixSolver{
|
||||
expandedSets: expandedSets,
|
||||
evictCosts: evictCosts,
|
||||
modelToSets: modelToSets,
|
||||
}
|
||||
}
|
||||
|
||||
// SolveResult describes what the solver decided.
|
||||
type SolveResult struct {
|
||||
Evict []string // running models that must be stopped
|
||||
TargetSet []string // the chosen set of models (for informational purposes)
|
||||
SetName string // name of the chosen set
|
||||
DSL string // original DSL expression for the chosen set
|
||||
TotalCost int // total eviction cost
|
||||
}
|
||||
|
||||
// Solve determines which models to evict when a model is requested.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. If requestedModel is already running, no eviction needed.
|
||||
// 2. Find all sets containing requestedModel.
|
||||
// 3. If no sets found, the model runs alone; evict all running models.
|
||||
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||
// models NOT in that set.
|
||||
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||
// 6. Return models to evict and the chosen set.
|
||||
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
|
||||
// If already running, nothing to do (but fill in set info for logging)
|
||||
if slices.Contains(runningModels, requestedModel) {
|
||||
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||
return SolveResult{
|
||||
TargetSet: runningModels,
|
||||
SetName: setName,
|
||||
DSL: dsl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
candidateIndices := s.modelToSets[requestedModel]
|
||||
|
||||
// Model not in any set: runs alone, evict everything
|
||||
if len(candidateIndices) == 0 {
|
||||
evict := make([]string, len(runningModels))
|
||||
copy(evict, runningModels)
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: []string{requestedModel},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Find the cheapest candidate set
|
||||
bestCost := -1
|
||||
bestIdx := -1
|
||||
|
||||
for _, idx := range candidateIndices {
|
||||
setModels := s.expandedSets[idx].Models
|
||||
cost := 0
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(setModels, running) {
|
||||
cost += s.evictCost(running)
|
||||
}
|
||||
}
|
||||
|
||||
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||
bestCost = cost
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// Determine which running models to evict
|
||||
chosen := s.expandedSets[bestIdx]
|
||||
var evict []string
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(chosen.Models, running) {
|
||||
evict = append(evict, running)
|
||||
}
|
||||
}
|
||||
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: chosen.Models,
|
||||
SetName: chosen.SetName,
|
||||
DSL: chosen.DSL,
|
||||
TotalCost: bestCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// findMatchingSet finds the expanded set that contains all running models.
|
||||
// Returns the set name and DSL, or empty strings if no match.
|
||||
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||
for _, idx := range s.modelToSets[requestedModel] {
|
||||
set := s.expandedSets[idx]
|
||||
allInSet := true
|
||||
for _, m := range runningModels {
|
||||
if !slices.Contains(set.Models, m) {
|
||||
allInSet = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSet {
|
||||
return set.SetName, set.DSL
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (s *MatrixSolver) evictCost(model string) int {
|
||||
if cost, ok := s.evictCosts[model]; ok {
|
||||
return cost
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Matrix manages processes using solver-based swap logic.
|
||||
type Matrix struct {
|
||||
sync.Mutex
|
||||
solver *MatrixSolver
|
||||
processes map[string]*Process // all processes keyed by real model name
|
||||
config config.Config
|
||||
proxyLogger *logmon.Monitor
|
||||
upstreamLogger *logmon.Monitor
|
||||
|
||||
// inflight tracks ProxyRequest calls that have released m.Lock but may
|
||||
// not yet have incremented Process.inFlightRequests. A concurrent
|
||||
// request that needs to evict models waits for inflight to drain under
|
||||
// m.Lock before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
|
||||
// races with Stop()'s Wait() and can be killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook invoked in the no-eviction path
|
||||
// after m.Lock is released but before the request is dispatched to
|
||||
// Process.ProxyRequest. Tests use it to park a request at the exact
|
||||
// race window to deterministically reproduce the race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||
// model defined in the config (any model can run alone even if not in a set).
|
||||
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *logmon.Monitor) *Matrix {
|
||||
processes := make(map[string]*Process)
|
||||
for modelID, modelConfig := range cfg.Models {
|
||||
processLogger := logmon.NewWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
|
||||
processes[modelID] = process
|
||||
}
|
||||
|
||||
evictCosts := cfg.Matrix.ResolvedEvictCosts()
|
||||
|
||||
return &Matrix{
|
||||
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
|
||||
processes: processes,
|
||||
config: cfg,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyRequest handles the swap logic and proxies the request to the model.
|
||||
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("model %s not found in matrix", modelID)
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
running := m.runningModels()
|
||||
result, err := m.solver.Solve(modelID, running)
|
||||
if err != nil {
|
||||
m.Unlock()
|
||||
return fmt.Errorf("matrix solver error: %w", err)
|
||||
}
|
||||
|
||||
// Log solver decision
|
||||
if len(result.Evict) > 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||
} else if len(running) == 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
|
||||
} else {
|
||||
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
|
||||
}
|
||||
|
||||
// Evict models that need to be stopped
|
||||
if len(result.Evict) > 0 {
|
||||
// Wait for any in-flight ProxyRequest calls to register on their
|
||||
// Process before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
m.inflight.Wait()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, evictModel := range result.Evict {
|
||||
if p, exists := m.processes[evictModel]; exists {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Stop()
|
||||
}(p)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Register this request in inflight before releasing m.Lock so a
|
||||
// concurrent eviction will wait for it to complete.
|
||||
m.inflight.Add(1)
|
||||
defer m.inflight.Done()
|
||||
isFastPath := len(result.Evict) == 0
|
||||
m.Unlock()
|
||||
|
||||
if isFastPath && m.testDelayFastPath != nil {
|
||||
m.testDelayFastPath()
|
||||
}
|
||||
|
||||
// Proxy the request (Process handles on-demand start)
|
||||
process.ProxyRequest(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProcesses stops all running processes.
|
||||
func (m *Matrix) StopProcesses(strategy StopStrategy) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
p.StopImmediately()
|
||||
default:
|
||||
p.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// StopProcess stops a single process by model ID.
|
||||
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down all processes.
|
||||
func (m *Matrix) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// RunningModels returns model names currently in an active (non-stopped) state.
|
||||
func (m *Matrix) RunningModels() []string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.runningModels()
|
||||
}
|
||||
|
||||
// runningModels returns running model names (caller must hold lock).
|
||||
func (m *Matrix) runningModels() []string {
|
||||
var running []string
|
||||
for id, process := range m.processes {
|
||||
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
|
||||
running = append(running, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
|
||||
// GetProcess returns the Process for a model.
|
||||
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
|
||||
p, ok := m.processes[modelID]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// HasModel returns true if the model is managed by this matrix.
|
||||
func (m *Matrix) HasModel(modelID string) bool {
|
||||
_, ok := m.processes[modelID]
|
||||
return ok
|
||||
}
|
||||
@@ -1,349 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper to build expanded sets for solver tests
|
||||
func makeExpandedSets(sets ...struct {
|
||||
name string
|
||||
models []string
|
||||
}) []config.ExpandedSet {
|
||||
var result []config.ExpandedSet
|
||||
for _, s := range sets {
|
||||
result = append(result, config.ExpandedSet{
|
||||
SetName: s.name,
|
||||
Models: s.models,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func es(name string, models ...string) struct {
|
||||
name string
|
||||
models []string
|
||||
} {
|
||||
return struct {
|
||||
name string
|
||||
models []string
|
||||
}{name, models}
|
||||
}
|
||||
|
||||
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"a"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a"}, result.TargetSet)
|
||||
assert.Equal(t, "s1", result.SetName)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Model "c" not in any set
|
||||
result, err := solver.Solve("c", []string{"a", "b"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("c", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
|
||||
// Set: [a, b]. Request a when b and c are running.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"b", "c"})
|
||||
require.NoError(t, err)
|
||||
// c is not in the set, so it gets evicted. b is in the set, so it stays.
|
||||
assert.Equal(t, []string{"c"}, result.Evict)
|
||||
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
|
||||
// Two sets containing model "a":
|
||||
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
|
||||
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "v"),
|
||||
es("s2", "a", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30},
|
||||
)
|
||||
|
||||
// v is running. Switching to a:
|
||||
// s1 cost: v is in s1, so 0
|
||||
// s2 cost: v is NOT in s2, so 50
|
||||
// => pick s1
|
||||
result, err := solver.Solve("a", []string{"v"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
|
||||
|
||||
// L is running. Switching to a:
|
||||
// s1 cost: L is NOT in s1, so 30
|
||||
// s2 cost: L is in s2, so 0
|
||||
// => pick s2
|
||||
result, err = solver.Solve("a", []string{"L"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
|
||||
// Two sets with identical cost. Definition order should win.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "x"),
|
||||
es("s2", "a", "y"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Nothing running, both sets cost 0. s1 is first.
|
||||
result, err := solver.Solve("a", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
|
||||
// Model "v" costs 50 to evict, "m" costs 1 (default).
|
||||
// Sets: [g,v], [g,m]
|
||||
// Running: v, m. Request g.
|
||||
// s1=[g,v]: evict m (cost 1), keep v
|
||||
// s2=[g,m]: evict v (cost 50), keep m
|
||||
// => pick s1
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "g", "m"),
|
||||
),
|
||||
map[string]int{"v": 50},
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{"v", "m"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"m"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "q", "v"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
|
||||
// cannot stop a process while an in-flight ProxyRequest for that process is
|
||||
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
|
||||
// matrix-level inflight tracking, the eviction's Stop() races with the
|
||||
// pending request and kills it mid-start.
|
||||
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"model1"}},
|
||||
{SetName: "s2", Models: []string{"model2"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
m := NewMatrix(cfg, testLogger, testLogger)
|
||||
defer m.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
m.processes["model1"].testHandler = newTestHandler("model1")
|
||||
m.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 so it reaches StateReady and
|
||||
// subsequent requests take the no-eviction path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
|
||||
|
||||
// Install fast-path hook that signals arrival and waits for release.
|
||||
// This parks R2 at the race window — after m.Lock is released but
|
||||
// before Process.inFlightRequests.Add(1).
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
m.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: no-eviction request for model1. Will pause at the hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: request for model2 which requires evicting model1. Must wait for
|
||||
// R2 to finish before touching model1.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired m.Lock and entered the eviction path. In
|
||||
// the fixed code, R3 then blocks on m.inflight.Wait() while still
|
||||
// holding the lock, so TryLock keeps failing.
|
||||
for m.TryLock() {
|
||||
m.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code R3 is blocked and nothing changes; in the
|
||||
// buggy code R3 will Stop() model1 and start model2 within microseconds.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if m.processes["model1"].CurrentState() != StateReady ||
|
||||
m.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while an in-flight request is pending")
|
||||
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is evicted")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||
// Simulates the example config:
|
||||
// standard: [g,v], [q,v], [m,v]
|
||||
// with_rerank: [g,v,e], [q,v,e]
|
||||
// creative: [g,sd], [q,sd]
|
||||
// full: [L]
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("standard", "g", "v"),
|
||||
es("standard", "q", "v"),
|
||||
es("standard", "m", "v"),
|
||||
es("with_rerank", "e", "g", "v"),
|
||||
es("with_rerank", "e", "q", "v"),
|
||||
es("creative", "g", "sd"),
|
||||
es("creative", "q", "sd"),
|
||||
es("full", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30, "whisper": 10},
|
||||
)
|
||||
|
||||
// Running: g, v. Request q.
|
||||
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
|
||||
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
|
||||
// => tie, pick first by definition order = standard[q,v]
|
||||
result, err := solver.Solve("q", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"g"}, result.Evict)
|
||||
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request L.
|
||||
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// Only one set contains L, so pick it.
|
||||
result, err = solver.Solve("L", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
|
||||
assert.Equal(t, []string{"L"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request sd.
|
||||
// creative[g,sd]: evict v (cost 50). Total: 50.
|
||||
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// => pick creative[g,sd]
|
||||
result, err = solver.Solve("sd", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"v"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
|
||||
|
||||
// Running: q, v, e. Request g.
|
||||
// standard[g,v]: evict q (1) + e (1). Total: 2.
|
||||
// with_rerank[g,v,e]: evict q (1). Total: 1.
|
||||
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
|
||||
// => pick with_rerank[g,v,e]
|
||||
result, err = solver.Solve("g", []string{"e", "q", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"q"}, result.Evict)
|
||||
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
|
||||
}
|
||||
@@ -1,689 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/mostlygeek/llama-swap/internal/cache"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||
var zstdEncOptions = []zstd.EOption{
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
}
|
||||
|
||||
// zstdDecOptions are the shared zstd decoder options.
|
||||
var zstdDecOptions = []zstd.DOption{}
|
||||
|
||||
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||
var zstdEncPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||
return enc
|
||||
},
|
||||
}
|
||||
|
||||
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||
var zstdDecPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||
// Returns compressed bytes and the original CBOR byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
cborBytes, err := cbor.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(zenc)
|
||||
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
|
||||
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
cborBytes, err := dec.DecodeAll(data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||
}
|
||||
var capture ReqRespCapture
|
||||
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||
}
|
||||
return &capture, nil
|
||||
}
|
||||
|
||||
// TokenMetrics holds token usage and performance metrics
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
}
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// ActivityLogEvent represents a token metrics event
|
||||
type ActivityLogEvent struct {
|
||||
Metrics ActivityLogEntry
|
||||
}
|
||||
|
||||
func (e ActivityLogEvent) Type() uint32 {
|
||||
return ActivityLogEventID // defined in events.go
|
||||
}
|
||||
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics ring.Buffer[ActivityLogEntry]
|
||||
nextID int
|
||||
logger *logmon.Monitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
mm := &metricsMonitor{
|
||||
logger: logger,
|
||||
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
}
|
||||
if captureBufferMB > 0 {
|
||||
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||
}
|
||||
return mm
|
||||
}
|
||||
|
||||
// queueMetrics adds a new metric to the collection without emitting an event.
|
||||
// Returns the assigned metric ID. Call emitMetric after capture setup.
|
||||
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics.Push(metric)
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// addCapture compresses and stores a capture in the cache.
|
||||
// Returns true if the capture was stored, false otherwise.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||
return false
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
return true
|
||||
}
|
||||
|
||||
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||
if mp.captureCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
data, err := mp.captureCache.Get(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return data, true
|
||||
}
|
||||
|
||||
// getCaptureByID decompresses and unmarshals a capture by ID.
|
||||
// Returns nil if the capture is not found or decompression fails.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
if mp.captureCache == nil {
|
||||
return nil
|
||||
}
|
||||
data, exists := mp.getCompressedBytes(id)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
capture, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return capture
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics with HasCapture resolved from cache.
|
||||
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return []ActivityLogEntry{}
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns metrics as JSON with HasCapture resolved from cache.
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return json.Marshal([]ActivityLogEntry{})
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// Capture field flags for controlling what is saved in ReqRespCapture.
|
||||
type captureFields uint
|
||||
|
||||
const (
|
||||
captureNone captureFields = 1 << iota
|
||||
captureReqHeaders
|
||||
captureReqBody
|
||||
captureRespHeaders
|
||||
captureRespBody
|
||||
)
|
||||
|
||||
const (
|
||||
captureReqAll = captureReqHeaders | captureReqBody
|
||||
captureRespAll = captureRespHeaders | captureRespBody
|
||||
captureAll = captureReqAll | captureRespAll
|
||||
)
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics.
|
||||
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
captureFields captureFields,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
}
|
||||
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
// Initialize default metrics - recorded for every request
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
ReqPath: request.URL.Path,
|
||||
RespContentType: recorder.Header().Get("Content-Type"),
|
||||
RespStatusCode: recorder.Status(),
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsed.Tokens
|
||||
tm.DurationMs = parsed.DurationMs
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsedMetrics.Tokens
|
||||
tm.DurationMs = parsedMetrics.DurationMs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
var respHeaders map[string]string
|
||||
var respBody []byte
|
||||
if (captureFields & captureRespHeaders) != 0 {
|
||||
respHeaders = make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
}
|
||||
if (captureFields & captureRespBody) != 0 {
|
||||
respBody = body
|
||||
}
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.queueMetrics(tm)
|
||||
tm.ID = metricID
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
if mp.addCapture(*capture) {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
mp.emitMetric(tm)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
// v1/chat/completions puts it at top-level "usage"; v1/responses nests under
|
||||
// "response.usage"; v1/messages emits it at "message.usage" on message_start
|
||||
// and at "usage" on message_delta.
|
||||
var usagePaths = []string{"usage", "response.usage", "message.usage"}
|
||||
|
||||
// extractUsageTokens reads input/output/cached token counts from a usage
|
||||
// gjson.Result, handling the field-name differences across endpoints.
|
||||
// cached returns -1 when the field is absent. ok is true when at least one
|
||||
// field was present.
|
||||
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
|
||||
cached = -1
|
||||
if !usage.Exists() {
|
||||
return
|
||||
}
|
||||
|
||||
if v := usage.Get("prompt_tokens"); v.Exists() {
|
||||
// v1/chat/completions
|
||||
input = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens"); v.Exists() {
|
||||
// v1/messages, v1/responses
|
||||
input = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("completion_tokens"); v.Exists() {
|
||||
// v1/chat/completions
|
||||
output = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("output_tokens"); v.Exists() {
|
||||
// v1/messages, v1/responses
|
||||
output = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
|
||||
// v1/messages (Anthropic)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
|
||||
// v1/responses (OpenAI Responses API)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
|
||||
// v1/chat/completions (OpenAI cache hits)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||
// Walk SSE "data:" lines forward, merging usage info from every event.
|
||||
// Different endpoints split usage across events:
|
||||
// - v1/chat/completions: usage on the final chunk before [DONE]
|
||||
// - v1/responses: usage on response.completed (response.usage)
|
||||
// - v1/messages: input + cache on message_start (message.usage),
|
||||
// output_tokens on message_delta (usage)
|
||||
// We take the latest informative value per field so all three are covered.
|
||||
|
||||
var (
|
||||
inputTokens, outputTokens int64
|
||||
cachedTokens int64 = -1
|
||||
hasAny bool
|
||||
timings gjson.Result
|
||||
)
|
||||
|
||||
prefix := []byte("data:")
|
||||
for offset := 0; offset < len(body); {
|
||||
nl := bytes.IndexByte(body[offset:], '\n')
|
||||
var line []byte
|
||||
if nl == -1 {
|
||||
line = body[offset:]
|
||||
offset = len(body)
|
||||
} else {
|
||||
line = body[offset : offset+nl]
|
||||
offset += nl + 1
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
continue
|
||||
}
|
||||
parsed := gjson.ParseBytes(data)
|
||||
|
||||
for _, path := range usagePaths {
|
||||
u := parsed.Get(path)
|
||||
if !u.Exists() {
|
||||
continue
|
||||
}
|
||||
i, o, c, ok := extractUsageTokens(u)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hasAny = true
|
||||
// Take the latest non-zero value so message_start's input_tokens
|
||||
// is preserved when message_delta's usage omits it, and vice versa
|
||||
// for output_tokens.
|
||||
if i > 0 {
|
||||
inputTokens = i
|
||||
}
|
||||
if o > 0 {
|
||||
outputTokens = o
|
||||
}
|
||||
if c >= 0 {
|
||||
cachedTokens = c
|
||||
}
|
||||
}
|
||||
if t := parsed.Get("timings"); t.Exists() {
|
||||
timings = t
|
||||
hasAny = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAny {
|
||||
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||
input, output, cached, _ := extractUsageTokens(usage)
|
||||
return buildMetrics(modelID, start, input, output, cached, timings), nil
|
||||
}
|
||||
|
||||
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
|
||||
// optional llama-server timings (which override input/output and provide rates).
|
||||
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
outputTokens = timings.Get("predicted_n").Int()
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
},
|
||||
DurationMs: durationMs,
|
||||
}
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||
bodyBuffer := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: bodyBuffer,
|
||||
tee: io.MultiWriter(w, bodyBuffer),
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,144 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *logmon.Monitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
// Create a transport with per-peer timeout configuration
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -1,311 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, peerProxy)
|
||||
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||
|
||||
// Verify the timeout values are actually applied to the transport
|
||||
member, found := peerProxy.proxyMap["model1"]
|
||||
require.True(t, found, "model1 should exist in proxyMap")
|
||||
assert.NotNil(t, member.reverseProxy)
|
||||
assert.NotNil(t, member.reverseProxy.Transport)
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
require.True(t, ok, "Transport should be *http.Transport")
|
||||
|
||||
// Verify all timeout values are correctly applied
|
||||
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
// ForceAttemptHTTP2 should be enabled
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -1,956 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type StopStrategy int
|
||||
|
||||
const (
|
||||
StopImmediately StopStrategy = iota
|
||||
StopWaitForInflightRequest
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config config.ModelConfig
|
||||
cmd *exec.Cmd
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
|
||||
// PR #155 called to cancel the upstream process
|
||||
cmdMutex sync.RWMutex
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
cmdWaitChan chan struct{}
|
||||
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandledMutex sync.RWMutex
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequestsCount atomic.Int32
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing concurrency limits
|
||||
concurrencyLimitSemaphore chan struct{}
|
||||
|
||||
// used for testing to override the default value
|
||||
gracefulStopTimeout time.Duration
|
||||
|
||||
// used for testing to bypass subprocess and reverse proxy
|
||||
testHandler http.Handler
|
||||
|
||||
// track the number of failed starts
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *logmon.Monitor, proxyLogger *logmon.Monitor) *Process {
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
// Setup the reverse proxy.
|
||||
proxyURL, err := url.Parse(config.Proxy)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||
}
|
||||
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
|
||||
// Create custom transport with configured timeouts
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.Transport = transport
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
reverseProxy: reverseProxy,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
|
||||
// concurrency limit
|
||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||
|
||||
// To be removed when migration over exec.CommandContext is complete
|
||||
// stop timeout
|
||||
gracefulStopTimeout: 10 * time.Second,
|
||||
cmdWaitChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// LogMonitor returns the log monitor associated with the process.
|
||||
func (p *Process) LogMonitor() *logmon.Monitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||
p.lastRequestHandledMutex.Lock()
|
||||
defer p.lastRequestHandledMutex.Unlock()
|
||||
p.lastRequestHandled = t
|
||||
}
|
||||
|
||||
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) getLastRequestHandled() time.Time {
|
||||
p.lastRequestHandledMutex.RLock()
|
||||
defer p.lastRequestHandledMutex.RUnlock()
|
||||
return p.lastRequestHandled
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
)
|
||||
|
||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||
// and an error if the swap failed.
|
||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != expectedState {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
|
||||
// Atomically increment waitStarting when entering StateStarting
|
||||
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||
if newState == StateStarting {
|
||||
p.waitStarting.Add(1)
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
// Helper function to encapsulate transition rules
|
||||
func isValidTransition(from, to ProcessState) bool {
|
||||
switch from {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateStopping || to == StateStopped
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// forceState forces the process state to the new state with mutex protection.
|
||||
// This should only be used in exceptional cases where the normal state transition
|
||||
// validation via swapState() cannot be used.
|
||||
func (p *Process) forceState(newState ProcessState) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.state = newState
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
// test-only fast path: skip subprocess, health check, and TTL goroutine
|
||||
if p.testHandler != nil {
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
curState = p.CurrentState()
|
||||
if curState == StateReady {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", curState)
|
||||
}
|
||||
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||
}
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
// Mimic the real stop path: cancelUpstream transitions
|
||||
// StateStopping -> StateStopped and closes cmdWaitChan,
|
||||
// matching what waitForCmd does for real subprocesses.
|
||||
ch := make(chan struct{})
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = func() {
|
||||
if curState := p.CurrentState(); curState == StateStopping {
|
||||
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
} else {
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
close(ch)
|
||||
}
|
||||
p.cmdWaitChan = ch
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
// already starting, just wait for it to complete and expect
|
||||
// it to be be in the Ready start after. If not, return an error
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
if state := p.CurrentState(); state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
setProcAttributes(p.cmd)
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.forceState(StateStopped) // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
go p.waitForCmd()
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
// Ready Check loop
|
||||
for {
|
||||
currentState := p.CurrentState()
|
||||
if currentState != StateStarting {
|
||||
if currentState == StateStopped {
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||
}
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
}
|
||||
|
||||
if time.Since(checkStartTime) > maxDuration {
|
||||
p.stopCommand()
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||
break
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||
}
|
||||
}
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
|
||||
if p.config.UnloadAfter > 0 {
|
||||
// start a goroutine to check every second if
|
||||
// the process should be stopped
|
||||
go func() {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.CurrentState() != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// skip the TTL check if there are inflight requests
|
||||
if p.inFlightRequestsCount.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop will wait for inflight requests to complete before stopping the process.
|
||||
func (p *Process) Stop() {
|
||||
|
||||
// guard to prevent multiple goroutines from stopping
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||
return
|
||||
}
|
||||
|
||||
// wait for any inflight requests before proceeding
|
||||
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||
p.inFlightRequests.Wait()
|
||||
p.StopImmediately()
|
||||
}
|
||||
|
||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||
func (p *Process) StopImmediately() {
|
||||
|
||||
// guard to prevent multiple goroutines from stopping the process
|
||||
enterState := p.CurrentState()
|
||||
if !isValidTransition(enterState, StateStopping) {
|
||||
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||
return
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
|
||||
if curState, err := p.swapState(enterState, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||
// the StateShutdown state, it can not be started again.
|
||||
func (p *Process) Shutdown() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.forceState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
|
||||
// free the buffer in processLogger so the memory can be recovered
|
||||
p.processLogger.Clear()
|
||||
}()
|
||||
|
||||
p.cmdMutex.RLock()
|
||||
cancelUpstream := p.cancelUpstream
|
||||
cmdWaitChan := p.cmdWaitChan
|
||||
p.cmdMutex.RUnlock()
|
||||
|
||||
if cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
cancelUpstream()
|
||||
<-cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
// wait a short time for a tcp connection to be established
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).DialContext,
|
||||
},
|
||||
|
||||
// give a long time to respond to the health check endpoint
|
||||
// after the connection is established. See issue: 276
|
||||
Timeout: 5000 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if p.reverseProxy == nil {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||
default:
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
p.inFlightRequestsCount.Add(1)
|
||||
defer func() {
|
||||
p.setLastRequestHandled(time.Now())
|
||||
p.inFlightRequestsCount.Add(-1)
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// for #366
|
||||
// - extract streaming param from request context, should have been set by proxymanager
|
||||
var srw *statusResponseWriter
|
||||
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
// start a goroutine to stream loading status messages into the response writer
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
|
||||
// PR #417 (no support for anthropic v1/messages yet)
|
||||
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
|
||||
}
|
||||
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
cancelLoadCtx()
|
||||
if srw != nil {
|
||||
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
|
||||
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
|
||||
// before closing the connection. Without this, the connection would close before
|
||||
// the goroutine can write its cleanup messages, causing incomplete SSE output.
|
||||
srw.waitForCompletion(100 * time.Millisecond)
|
||||
} else {
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
// should trigger srw to stop sending loading events ...
|
||||
cancelLoadCtx()
|
||||
|
||||
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||
// disconnects before the response is sent
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if r == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if srw != nil {
|
||||
// Wait for the goroutine to finish writing its final messages
|
||||
const completionTimeout = 1 * time.Second
|
||||
if !srw.waitForCompletion(completionTimeout) {
|
||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if p.testHandler != nil {
|
||||
p.testHandler.ServeHTTP(w, r)
|
||||
} else if srw != nil {
|
||||
p.reverseProxy.ServeHTTP(srw, r)
|
||||
} else {
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||
p.ID, r.RequestURI, startDuration, totalTime)
|
||||
}
|
||||
|
||||
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||
func (p *Process) waitForCmd() {
|
||||
exitErr := p.cmd.Wait()
|
||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||
|
||||
if exitErr != nil {
|
||||
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentState := p.CurrentState()
|
||||
switch currentState {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.forceState(StateStopped) // force it to be in this state
|
||||
}
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
close(p.cmdWaitChan)
|
||||
p.cmdMutex.Unlock()
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
func (p *Process) cmdStopUpstreamProcess() error {
|
||||
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||
|
||||
// this should never happen ...
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger returns the logger for this process.
|
||||
func (p *Process) Logger() *logmon.Monitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
"Waking up the hamsters...",
|
||||
"Teaching the model manners...",
|
||||
"Convincing the GPU to participate...",
|
||||
"Loading weights (they're heavy)...",
|
||||
"Herding electrons...",
|
||||
"Compiling excuses for the delay...",
|
||||
"Downloading more RAM...",
|
||||
"Asking the model nicely to boot up...",
|
||||
"Bribing CUDA with cookies...",
|
||||
"Still loading (blame VRAM)...",
|
||||
"The model is fashionably late...",
|
||||
"Warming up those tensors...",
|
||||
"Making the neural net do push-ups...",
|
||||
"Your patience is appreciated (really)...",
|
||||
"Almost there (probably)...",
|
||||
"Loading like it's 1999...",
|
||||
"The model forgot where it put its keys...",
|
||||
"Quantum tunneling through layers...",
|
||||
"Negotiating with the PCIe bus...",
|
||||
"Defrosting frozen parameters...",
|
||||
"Teaching attention heads to focus...",
|
||||
"Running the matrix (slowly)...",
|
||||
"Untangling transformer blocks...",
|
||||
"Calibrating the flux capacitor...",
|
||||
"Spinning up the probability wheels...",
|
||||
"Waiting for the GPU to wake from its nap...",
|
||||
"Converting caffeine to compute...",
|
||||
"Allocating virtual patience...",
|
||||
"Performing arcane CUDA rituals...",
|
||||
"The model is stuck in traffic...",
|
||||
"Inflating embeddings...",
|
||||
"Summoning computational demons...",
|
||||
"Pleading with the OOM killer...",
|
||||
"Calculating the meaning of life (still at 42)...",
|
||||
"Training the training wheels...",
|
||||
"Optimizing the optimizer...",
|
||||
"Bootstrapping the bootstrapper...",
|
||||
"Loading loading screen...",
|
||||
"Processing processing logs...",
|
||||
"Buffering buffer overflow jokes...",
|
||||
"The model hit snooze...",
|
||||
"Debugging the debugger...",
|
||||
"Compiling the compiler...",
|
||||
"Parsing the parser (meta)...",
|
||||
"Tokenizing tokens...",
|
||||
"Encoding the encoder...",
|
||||
"Hashing hash browns...",
|
||||
"Forking spoons (not forks)...",
|
||||
"The model is contemplating existence...",
|
||||
"Transcending dimensional barriers...",
|
||||
"Invoking elder tensor gods...",
|
||||
"Unfurling probability clouds...",
|
||||
"Synchronizing parallel universes...",
|
||||
"The GPU is having second thoughts...",
|
||||
"Recalibrating reality matrices...",
|
||||
"Time is an illusion, loading doubly so...",
|
||||
"Convincing bits to flip themselves...",
|
||||
"The model is reading its own documentation...",
|
||||
}
|
||||
|
||||
type statusResponseWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
process *Process
|
||||
wg sync.WaitGroup // Track goroutine completion
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
|
||||
s := &statusResponseWriter{
|
||||
writer: w,
|
||||
process: p,
|
||||
start: time.Now(),
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream") // SSE
|
||||
s.Header().Set("Cache-Control", "no-cache") // no-cache
|
||||
s.Header().Set("Connection", "keep-alive") // keep-alive
|
||||
s.WriteHeader(http.StatusOK) // send status code 200
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
|
||||
return s
|
||||
}
|
||||
|
||||
// statusUpdates sends status updates to the client while the model is loading
|
||||
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Recover from panics caused by client disconnection
|
||||
// Note: recover() only works within the same goroutine, so we need it here
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(s.start)
|
||||
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
// Create a shuffled copy of loadingRemarks
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
// Pick a random duration to send a remark
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Now()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if s.process.CurrentState() == StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's time for a snarky remark
|
||||
if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendLine(fmt.Sprintf("\n%s", remark))
|
||||
lastRemarkTime = time.Now()
|
||||
// Pick a new random duration for the next remark
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCompletion waits for the statusUpdates goroutine to finish
|
||||
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendLine(line string) {
|
||||
s.sendData(line + "\n")
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendData(data string) {
|
||||
// Create the proper SSE JSON structure
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write SSE formatted data, panic if not able to write
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
|
||||
}
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -1,609 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
debugLogger = logmon.NewWriter(os.Stdout)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// flip to help with debugging tests
|
||||
if false {
|
||||
debugLogger.SetLogLevel(logmon.LevelDebug)
|
||||
} else {
|
||||
debugLogger.SetLogLevel(logmon.LevelError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// process is automatically started
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
|
||||
// Stop the process
|
||||
process.Stop()
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Proxy the request
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// should have automatically started the process again
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := config.ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
|
||||
process.ProxyRequest(w, req1)
|
||||
|
||||
t.Log("sending slow first request (4 seconds)")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "1234")
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// ensure the TTL timeout does not race slow requests (see issue #25)
|
||||
t.Log("sending second request (1 second)")
|
||||
time.Sleep(time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req2)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// wait 5 seconds
|
||||
t.Log("sleep 5 seconds and check if unloaded")
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
"12345": "",
|
||||
"abcde": "",
|
||||
"fghij": "",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range results {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
|
||||
}(key)
|
||||
}
|
||||
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping Exit Interrupts Health Check test")
|
||||
}
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := config.ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long concurrency limit test")
|
||||
}
|
||||
|
||||
expectedMessage := "concurrency_limit_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// only allow 1 concurrent request at a time
|
||||
config.ConcurrencyLimit = 1
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||
defer process.Stop()
|
||||
|
||||
// launch a goroutine first to take up the semaphore
|
||||
go func() {
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req1)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}()
|
||||
|
||||
// let the goroutine start
|
||||
<-time.After(time.Millisecond * 25)
|
||||
|
||||
denied := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, denied)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestProcess_StopImmediately(t *testing.T) {
|
||||
expectedMessage := "test_stop_immediate"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
}()
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
conf := config.ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
process.gracefulStopTimeout = time.Second
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// StatusOK because that was already sent before the kill
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||
// then the unexpected EOF is sent after the kill
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
// Upstream may be killed mid-response.
|
||||
// Assert an incomplete or partial response.
|
||||
assert.NotEqual(t, "12345", w.Body.String())
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
}()
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
conf.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := conf
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
defer process1.Stop()
|
||||
process2.start()
|
||||
defer process2.Stop()
|
||||
|
||||
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||
// handled appropriately.
|
||||
//
|
||||
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||
// response is sent from some reason.
|
||||
//
|
||||
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
expectedMessage := "panic_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// Start the process
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// Create a custom ResponseWriter that simulates a client disconnect
|
||||
// by panicking when Write is called after headers are sent
|
||||
panicWriter := &panicOnWriteResponseWriter{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
shouldPanic: true,
|
||||
}
|
||||
|
||||
// Make a request that will trigger the panic
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||
|
||||
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||
// ProxyRequest should catch and handle this panic gracefully.
|
||||
process.ProxyRequest(panicWriter, req)
|
||||
|
||||
// If we get here, the panic was properly recovered in ProxyRequest
|
||||
// The process should still be in a ready state
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||
// to simulate a client disconnect after headers are sent
|
||||
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||
type panicOnWriteResponseWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
shouldPanic bool
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||
w.headerWritten = true
|
||||
w.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.shouldPanic && w.headerWritten {
|
||||
// Simulate the panic that httputil.ReverseProxy throws
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
func TestProcess_CustomTimeouts(t *testing.T) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Cmd: "echo test",
|
||||
Proxy: "http://localhost:8080",
|
||||
CheckEndpoint: "/health",
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 120,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
}
|
||||
|
||||
debugLogger := logmon.NewWriter(io.Discard)
|
||||
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
|
||||
|
||||
// Verify the process was created successfully
|
||||
assert.NotNil(t, process)
|
||||
assert.Equal(t, "test-model", process.ID)
|
||||
assert.NotNil(t, process.reverseProxy)
|
||||
assert.NotNil(t, process.reverseProxy.Transport)
|
||||
|
||||
// Verify it's using http.Transport (not some other type)
|
||||
transport, ok := process.reverseProxy.Transport.(*http.Transport)
|
||||
assert.True(t, ok, "Transport should be *http.Transport")
|
||||
assert.NotNil(t, transport)
|
||||
|
||||
// Verify the timeouts are correctly applied
|
||||
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config config.Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *logmon.Monitor
|
||||
upstreamLogger *logmon.Monitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
|
||||
// inflight tracks fast-path requests (requests for the already-selected
|
||||
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
|
||||
// and Done() on completion; a concurrent swap request calls inflight.Wait()
|
||||
// under pg.Lock before stopping the current process. Without this tracking,
|
||||
// a fast-path request that has released pg.Lock but has not yet called
|
||||
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
|
||||
// killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
|
||||
// the fast path after pg.Lock is released but before the request is
|
||||
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
|
||||
// request at the exact race window to deterministically reproduce the
|
||||
// fast-path vs swap race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config config.Config, proxyLogger *logmon.Monitor, upstreamLogger *logmon.Monitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
processLogger := logmon.NewWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
|
||||
// Wait for in-flight fast-path requests to drain before stopping
|
||||
// the previous process. Without this, a fast-path request that has
|
||||
// released pg.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
pg.inflight.Wait()
|
||||
|
||||
// is there something already running?
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
|
||||
// wait for the request to the new model to be fully handled
|
||||
// and prevent race conditions see issue #277
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
pg.lastUsedProcess = modelID
|
||||
|
||||
// short circuit and exit
|
||||
pg.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path: register this request in inflight before releasing
|
||||
// pg.Lock so a concurrent swap will wait for it to complete.
|
||||
pg.inflight.Add(1)
|
||||
defer pg.inflight.Done()
|
||||
pg.Unlock()
|
||||
|
||||
if pg.testDelayFastPath != nil {
|
||||
pg.testDelayFastPath()
|
||||
}
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
if strategy != StopImmediately {
|
||||
pg.inflight.Wait()
|
||||
}
|
||||
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
// use the same listening so if a model is already running, it will fail
|
||||
// this is a way to test that swap isolation is working
|
||||
// properly when there are parallel requests made at the
|
||||
// same time.
|
||||
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(tests))
|
||||
for _, modelName := range tests {
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
}(modelName)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
|
||||
// request cannot stop the current process while a fast-path request (for the
|
||||
// already-selected model) is in flight. Without ProcessGroup-level inflight
|
||||
// tracking, a fast-path request that has released pg.Lock but has not yet
|
||||
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
|
||||
// process is killed mid-request.
|
||||
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 via the swap path so that
|
||||
// lastUsedProcess == "model1" and subsequent model1 requests take the
|
||||
// fast path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
|
||||
|
||||
// Fast-path hook: signal arrival at the race window, then wait for
|
||||
// release. This parks R2 deterministically at the point where pg.Lock
|
||||
// has been released but Process.inFlightRequests has not yet been
|
||||
// incremented — the exact window the race exploits.
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
pg.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: fast-path request for model1. Will pause at the test hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: swap request for model2. Must wait for R2 to finish before touching
|
||||
// model1, otherwise model1 gets killed mid-request.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired pg.Lock and entered the swap critical
|
||||
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
|
||||
// still holding the lock, so TryLock keeps failing.
|
||||
for pg.TryLock() {
|
||||
pg.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
|
||||
// nothing changes, so we wait the full window. In the buggy code, R3
|
||||
// will Stop() model1 and start serving via model2 within microseconds —
|
||||
// we exit early once the mutation is observable.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if pg.processes["model1"].CurrentState() != StateReady ||
|
||||
pg.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while a fast-path request is in flight")
|
||||
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is swapped out")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
|
||||
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
|
||||
// process while a fast-path ProxyRequest is in the [pg.Unlock,
|
||||
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
|
||||
// StopProcesses, the external caller bypasses the inflight guard and kills the
|
||||
// process mid-request.
|
||||
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
|
||||
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopImmediately)
|
||||
|
||||
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: model1 is active so subsequent model1 requests take the fast path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||
|
||||
// Park a fast-path request at the race window.
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
pg.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
<-r2Reached
|
||||
|
||||
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
|
||||
// the group while a fast-path request is in flight.
|
||||
r3Done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
pg.StopProcesses(StopWaitForInflightRequest)
|
||||
}()
|
||||
|
||||
// Spin until StopProcesses has acquired pg.Lock.
|
||||
for pg.TryLock() {
|
||||
pg.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
|
||||
// and model1 stays Ready. In the buggy code it proceeds immediately and
|
||||
// kills model1.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if pg.processes["model1"].CurrentState() != StateReady {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-r3Done:
|
||||
goto done
|
||||
default:
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
done:
|
||||
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while a fast-path request is in flight")
|
||||
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user