Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ed77385d08 | |||
| 92b90447e8 | |||
| 62aea0e83d | |||
| 8c660dcb90 | |||
| f6877b8175 | |||
| 9b3a33d7b9 | |||
| 0cfe5a6639 | |||
| 44e1501e81 | |||
| 46cea36bc2 | |||
| ccfba0df28 | |||
| ddfae90b19 | |||
| 29d3d9ba20 | |||
| 9be9a87fa0 | |||
| 6ea551362e | |||
| 03d58e53fa | |||
| c790d0ee03 |
@@ -44,13 +44,10 @@ jobs:
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 #v6.2.0
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install check-jsonschema
|
||||
run: pip install check-jsonschema
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Validate config.example.yaml against schema
|
||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||
run: go test ./internal/config/ -run TestConfig_ExampleMatchesSchema -v
|
||||
|
||||
@@ -2,10 +2,10 @@ name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# the llama.cpp daily packages have time to build and publish (~8hr after llama.cpp project's cron)
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
- cron: "00 12,18 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -19,21 +19,17 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/... ./internal/...
|
||||
staticcheck ./proxy/... ./internal/... || true
|
||||
test-dev:
|
||||
go test -short ./...
|
||||
staticcheck ./... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/... ./internal/...
|
||||
test:
|
||||
go test -short -count=1 ./internal/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/... ./internal/...
|
||||
test-all:
|
||||
go test -race -count=1 ./internal/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui-svelte && npm install
|
||||
@@ -64,7 +60,7 @@ windows: ui
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
# for testing with real external processes
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/watcher"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
mainLogger := logmon.New()
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("version: %s (%s), built at %s", version, commit, date)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("Error loading config: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(conf.Profiles) > 0 {
|
||||
mainLogger.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) {
|
||||
case "debug":
|
||||
mainLogger.SetLogLevel(logmon.LevelDebug)
|
||||
case "info":
|
||||
mainLogger.SetLogLevel(logmon.LevelInfo)
|
||||
case "warn":
|
||||
mainLogger.SetLogLevel(logmon.LevelWarn)
|
||||
case "error":
|
||||
mainLogger.SetLogLevel(logmon.LevelError)
|
||||
default:
|
||||
mainLogger.SetLogLevel(logmon.LevelInfo)
|
||||
}
|
||||
|
||||
mainLogger.Debugf("PID: %d", os.Getpid())
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
var mon *perf.Monitor
|
||||
if !conf.Performance.Disabled {
|
||||
mon, err = perf.New(conf.Performance, mainLogger)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("failed to create monitor: %s", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
mon.Start()
|
||||
} else {
|
||||
mainLogger.Info("performance monitoring is disabled")
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
|
||||
// Context that bounds the lifetime of background watcher goroutines.
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create server with initial handlergit
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloading := false
|
||||
var reloadMutex sync.Mutex
|
||||
reloadProxyManager := func() {
|
||||
reloadMutex.Lock()
|
||||
if reloading {
|
||||
reloadMutex.Unlock()
|
||||
return
|
||||
}
|
||||
reloading = true
|
||||
reloadMutex.Unlock()
|
||||
defer func() {
|
||||
reloadMutex.Lock()
|
||||
reloading = false
|
||||
reloadMutex.Unlock()
|
||||
}()
|
||||
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
mainLogger.Info("Reloading Configuration")
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Warnf("Unable to reload configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
mainLogger.Debug("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
if mon != nil {
|
||||
mon.UpdateConfig(conf.Performance)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
newPM.SetPerfMonitor(mon)
|
||||
srv.Handler = newPM
|
||||
mainLogger.Debug("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("Unable to load configuration: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
newPM.SetPerfMonitor(mon)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
|
||||
if *watchConfig {
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err)
|
||||
return
|
||||
}
|
||||
mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)")
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: func() {
|
||||
reloadProxyManager()
|
||||
},
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
// Signal handling
|
||||
go func() {
|
||||
for {
|
||||
sig := <-sigChan
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
mainLogger.Debug("Received SIGHUP")
|
||||
reloadProxyManager()
|
||||
case syscall.SIGINT, syscall.SIGTERM:
|
||||
mainLogger.Debugf("Received signal %v, shutting down...", sig)
|
||||
if mon != nil {
|
||||
mon.Stop()
|
||||
}
|
||||
watcherCancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
mainLogger.Errorf("Server shutdown: %v", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
default:
|
||||
// do nothing on other signals
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
mainLogger.Infof("llama-swap listening on http://%s", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
mainLogger.Errorf("Fatal server error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
+209
-73
@@ -82,6 +82,78 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections."
|
||||
},
|
||||
"groupsConfig": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"matrixConfig": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
@@ -306,81 +378,68 @@
|
||||
},
|
||||
"timeouts": {
|
||||
"$ref": "#/definitions/timeouts"
|
||||
},
|
||||
"capabilities": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"in": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"uniqueItems": true,
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"audio",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"description": "List of input modalities understood by the model."
|
||||
},
|
||||
"out": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"uniqueItems": true,
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"audio",
|
||||
"image"
|
||||
]
|
||||
},
|
||||
"description": "List of output modalities generated by the model."
|
||||
},
|
||||
"tools": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model supports function calling."
|
||||
},
|
||||
"reranker": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model supports the /v1/rerank endpoint."
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Maximum token context length supported by the model."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Defines what the model accepts for input, output and other metadata. Used in v1/models to inform clients what the model can do. An empty capabilities block (all zero values) is treated as not configured."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
"$ref": "#/definitions/groupsConfig"
|
||||
},
|
||||
"matrix": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
"$ref": "#/definitions/matrixConfig"
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
@@ -512,28 +571,105 @@
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
},
|
||||
"routing": {
|
||||
"type": "object",
|
||||
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||
"properties": {
|
||||
"scheduler": {
|
||||
"type": "object",
|
||||
"description": "Scheduler configuration. Decides the order in which queued requests are serviced.",
|
||||
"properties": {
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"fifo"
|
||||
],
|
||||
"default": "fifo",
|
||||
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fifo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {
|
||||
"type": "object",
|
||||
"description": "Per-model priority. Keys are model IDs, values are integers (default 0). Higher values are serviced first.",
|
||||
"additionalProperties": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"router": {
|
||||
"type": "object",
|
||||
"description": "Router configuration. Selects between the group and matrix swapping strategies.",
|
||||
"properties": {
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"group",
|
||||
"matrix"
|
||||
],
|
||||
"default": "group",
|
||||
"description": "Router to use. 'group' uses static groups, 'matrix' uses the solver-based swap matrix."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"groups": {
|
||||
"$ref": "#/definitions/groupsConfig"
|
||||
},
|
||||
"matrix": {
|
||||
"$ref": "#/definitions/matrixConfig"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"allOf": [
|
||||
{
|
||||
"if": {
|
||||
"required": ["groups"]
|
||||
"required": [
|
||||
"groups"
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["matrix"]
|
||||
"required": [
|
||||
"matrix"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"if": {
|
||||
"required": ["matrix"]
|
||||
"required": [
|
||||
"matrix"
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["groups"]
|
||||
"required": [
|
||||
"groups"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
+201
-87
@@ -312,6 +312,37 @@ models:
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# capabilities: defines what the model accepts for input, output and other metadata
|
||||
# - optional; omitted or all-zero means no capabilities
|
||||
# - used in v1/models to inform clients what the model can do
|
||||
capabilities:
|
||||
# in: list of modalities understood by the model
|
||||
# - default: []
|
||||
# - valid: text, audio, image
|
||||
in:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
# out: list of modalities generated by the model
|
||||
# - default: []
|
||||
# - valid: text, audio, image
|
||||
out:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
# tools: the model supports function calling
|
||||
# - default: false
|
||||
tools: true
|
||||
|
||||
# reranker: the model supports the /v1/rerank endpoint
|
||||
# - default: false
|
||||
reranker: false
|
||||
|
||||
# context: the maximum token context length supported
|
||||
# - default: 0
|
||||
# - must be an integer > 0
|
||||
context: 32000
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -343,93 +374,6 @@ models:
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# =============================================================================
|
||||
# matrix: run concurrent models with a solver-based swap DSL
|
||||
# =============================================================================
|
||||
#
|
||||
# Matrix or Groups?
|
||||
#
|
||||
# Groups are available and fully supported. The syntax may be easier to use
|
||||
# for simple use cases.
|
||||
#
|
||||
# Documentation can be found here:
|
||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||
#
|
||||
# A config can only use a matrix (recommended) or groups. A configuration error
|
||||
# will occur if both are defined. Groups is legacy but is fully supported with
|
||||
# no plans to deprecate it.
|
||||
#
|
||||
# ~~~~~
|
||||
#
|
||||
# The matrix declares valid combinations of models that can run concurrently.
|
||||
# When a model is requested, the solver finds the cheapest way to make it
|
||||
# available by evicting as few (and least costly) running models as possible.
|
||||
#
|
||||
# Solver behavior:
|
||||
# 1. Request arrives for model X
|
||||
# 2. If X is already running, forward immediately. Done.
|
||||
# 3. Find all sets containing X
|
||||
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||
# every running model NOT in that set
|
||||
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||
# 6. Evict what needs to stop. Start X. Forward request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||
# Only the requested model is started — others are not preloaded.
|
||||
#
|
||||
# A model not appearing in any set can only run alone.
|
||||
#
|
||||
matrix:
|
||||
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||
# - required for sets and evict_costs settings
|
||||
# - each entry is a short name to a real model ID. Do not use an alias
|
||||
# - used to keep set DSL logic short and easier to read
|
||||
# - sets and evict_costs only use identifiers defined in vars
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named sets of concurrent model combinations
|
||||
# Values are DSL strings with operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline another set's expression
|
||||
#
|
||||
# Expansion examples:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → expands llms inline, then applies & v
|
||||
sets:
|
||||
# LLM + TTS: switching between g/q/m won't evict v
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# LLM + TTS + reranker
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# LLM + image generation, no TTS
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# 70B model uses all GPUs, can only run alone
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
@@ -446,6 +390,176 @@ hooks:
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# routing:
|
||||
# Controls how llama-swap decides which models can run at the same time and
|
||||
# which get swapped out. Choose one of two swap engines:
|
||||
#
|
||||
# - group: the default engine. Simpler to configure. You define groups of
|
||||
# models that run together, and loading one group typically unloads
|
||||
# the others.
|
||||
#
|
||||
# - matrix: the newer engine. More involved to configure, but far more
|
||||
# flexible. It uses a small expression language to describe which
|
||||
# model combinations are allowed to run concurrently, enabling
|
||||
# setups that groups cannot express.
|
||||
#
|
||||
# The routing section is optional.
|
||||
routing:
|
||||
router:
|
||||
# use: a string defining which engine to use
|
||||
# - optional, default: "group"
|
||||
# - valid values: group, matrix
|
||||
use: group
|
||||
|
||||
# settings: a dictionary of settings for the specific engines
|
||||
settings:
|
||||
# groups: a dictionary of named groups
|
||||
# - optional, default: empty dictionary
|
||||
# - lets you keep some models loaded while others swap out
|
||||
# - every member must be a model ID defined in the models section
|
||||
# - a model can belong to only one group
|
||||
# - behaviour is set per group with the `swap`, `exclusive` and
|
||||
# `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the model names below are illustrative and are not defined above.
|
||||
groups:
|
||||
# group1 reproduces llama-swap's default behaviour: only one model
|
||||
# runs at a time across the entire instance.
|
||||
"group1":
|
||||
# swap: how members of this group swap among themselves
|
||||
# - optional, default: true
|
||||
# - true: only one member runs at a time
|
||||
# - false: all members can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive: how this group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: running a member unloads every other group
|
||||
# - false: running a member leaves other groups untouched
|
||||
exclusive: true
|
||||
|
||||
# members: the model IDs in this group
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# group2: members all run together, but loading any other group
|
||||
# unloads them.
|
||||
"group2":
|
||||
# swap: false lets all members stay loaded at once
|
||||
swap: false
|
||||
|
||||
# exclusive: false means requesting a member loads it without
|
||||
# unloading any other group
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# forever: a persistent group that other groups can never unload.
|
||||
"forever":
|
||||
# persistent: other groups cannot unload this group's members
|
||||
# - optional, default: false
|
||||
# - has no effect on swapping within the group
|
||||
persistent: true
|
||||
|
||||
# swap/exclusive: false keeps all members loaded and avoids
|
||||
# unloading other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# The matrix lists the model combinations that are allowed to run
|
||||
# concurrently. When a model is requested, the solver makes room for it
|
||||
# by evicting as few running models as possible, preferring to keep the
|
||||
# costliest ones loaded.
|
||||
#
|
||||
# Solver behaviour:
|
||||
# 1. A request arrives for model X.
|
||||
# 2. If X is already running, forward the request. Done.
|
||||
# 3. Collect every set that contains X.
|
||||
# 4. For each set, add up the evict_costs of the running models that
|
||||
# are NOT in that set — that is the set's cost.
|
||||
# 5. Choose the lowest-cost set. Break ties by definition order.
|
||||
# 6. Evict the models outside that set, start X, forward the request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] also permits any subset of itself.
|
||||
# Only the requested model is started; the others are not preloaded.
|
||||
#
|
||||
# A model that appears in no set can only run on its own.
|
||||
#
|
||||
matrix:
|
||||
# vars: short aliases for model IDs (alphanumeric, 1-8 chars)
|
||||
# - required: sets and evict_costs reference these names, not model IDs
|
||||
# - map each short name to a real model ID (not a model alias)
|
||||
# - keeps the set expressions short and readable
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named combinations of models that may run together.
|
||||
# Each value is an expression built from these operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline the expression of another set
|
||||
#
|
||||
# Each expression expands into one or more concrete sets:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → inline the llms set, then AND with v
|
||||
sets:
|
||||
# An LLM plus TTS. Switching between g/q/m keeps v loaded.
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# An LLM plus TTS plus reranker.
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# An LLM plus image generation, no TTS.
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# The 70B model uses every GPU, so it can only run alone.
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# scheduler: how queued requests are ordered.
|
||||
# The default and only valid scheduler is "fifo"
|
||||
scheduler:
|
||||
use: fifo
|
||||
settings:
|
||||
fifo:
|
||||
# priority: a dictionary of model ID -> priority
|
||||
# - optional, default: empty dictionary
|
||||
# - models default to priority 0
|
||||
# - higher priority requests are serviced first in the queue
|
||||
priority:
|
||||
A: 10
|
||||
B: 5
|
||||
C: 5
|
||||
D: 1
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
|
||||
@@ -2,10 +2,6 @@ ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
# TARGETARCH is auto-set by `docker buildx build --platform …` (amd64/arm64);
|
||||
# falls back to amd64 when an older `docker build` runs without buildx.
|
||||
ARG TARGETARCH=amd64
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
@@ -37,9 +33,15 @@ WORKDIR /app
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz"
|
||||
set -eux; \
|
||||
case "$(uname -m)" in \
|
||||
x86_64) ARCH=amd64 ;; \
|
||||
aarch64) ARCH=arm64 ;; \
|
||||
*) echo "unsupported arch: $(uname -m)" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
curl --fail -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
|
||||
@@ -9,12 +9,14 @@ require (
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.1
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/jsonschema-go v0.4.3
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/shirou/gopsutil/v4 v4.26.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.41.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -70,7 +72,6 @@ require (
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -61,6 +61,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
|
||||
@@ -129,13 +129,16 @@ type Config struct {
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// swap matrix: solver-based alternative to groups
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
// routing is the canonical source for swap/scheduling configuration.
|
||||
// New code must read Routing, never the backwards-compat fields below.
|
||||
Routing RoutingConfig `yaml:"routing"`
|
||||
|
||||
// populated during validation when matrix is configured
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
// Groups and Matrix are permanent backwards-compat input fields for the
|
||||
// legacy top-level `groups:`/`matrix:` keys. They are normalized into
|
||||
// Routing by LoadConfigFromReader. New code must not read them directly.
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
@@ -162,6 +165,35 @@ type Config struct {
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||
type RoutingConfig struct {
|
||||
Scheduler SchedulerConfig `yaml:"scheduler"`
|
||||
Router RouterConfig `yaml:"router"`
|
||||
}
|
||||
|
||||
type SchedulerConfig struct {
|
||||
Use string `yaml:"use"` // default "fifo"
|
||||
Settings SchedulerSettings `yaml:"settings"`
|
||||
}
|
||||
|
||||
type SchedulerSettings struct {
|
||||
Fifo FifoConfig `yaml:"fifo"`
|
||||
}
|
||||
|
||||
type FifoConfig struct {
|
||||
Priority map[string]int `yaml:"priority"` // model ID -> priority, default 0
|
||||
}
|
||||
|
||||
type RouterConfig struct {
|
||||
Use string `yaml:"use"` // "group" (default) | "matrix"
|
||||
Settings RouterSettings `yaml:"settings"`
|
||||
}
|
||||
|
||||
type RouterSettings struct {
|
||||
Groups map[string]GroupConfig `yaml:"groups"`
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
@@ -415,6 +447,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
@@ -455,6 +491,34 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
@@ -465,7 +529,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.ExpandedSets = expandedSets
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
@@ -487,6 +551,29 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "fifo"
|
||||
}
|
||||
if config.Routing.Scheduler.Use != "fifo" {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
|
||||
@@ -173,6 +173,25 @@ groups:
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expectedGroups := map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
@@ -246,22 +265,16 @@ groups:
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
Groups: expectedGroups,
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
Settings: RouterSettings{
|
||||
Groups: expectedGroups,
|
||||
},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestConfig_ExampleMatchesSchema validates that config.example.yaml conforms to
|
||||
// config-schema.json. Both files live at the repository root.
|
||||
func TestConfig_ExampleMatchesSchema(t *testing.T) {
|
||||
const (
|
||||
schemaPath = "../../config-schema.json"
|
||||
examplePath = "../../config.example.yaml"
|
||||
)
|
||||
|
||||
schemaBytes, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading %s: %v", schemaPath, err)
|
||||
}
|
||||
|
||||
var schema jsonschema.Schema
|
||||
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
|
||||
t.Fatalf("unmarshalling schema: %v", err)
|
||||
}
|
||||
|
||||
resolved, err := schema.Resolve(&jsonschema.ResolveOptions{
|
||||
BaseURI: "https://github.com/mostlygeek/llama-swap/",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolving schema: %v", err)
|
||||
}
|
||||
|
||||
exampleBytes, err := os.ReadFile(examplePath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading %s: %v", examplePath, err)
|
||||
}
|
||||
|
||||
// Convert YAML to a JSON-like value so numbers and keys match what the
|
||||
// validator expects.
|
||||
var yamlValue any
|
||||
if err := yaml.Unmarshal(exampleBytes, &yamlValue); err != nil {
|
||||
t.Fatalf("unmarshalling example yaml: %v", err)
|
||||
}
|
||||
jsonBytes, err := json.Marshal(yamlValue)
|
||||
if err != nil {
|
||||
t.Fatalf("converting example to json: %v", err)
|
||||
}
|
||||
var instance any
|
||||
if err := json.Unmarshal(jsonBytes, &instance); err != nil {
|
||||
t.Fatalf("unmarshalling example json: %v", err)
|
||||
}
|
||||
|
||||
if err := resolved.Validate(instance); err != nil {
|
||||
t.Fatalf("config.example.yaml does not match config-schema.json:\n%v", err)
|
||||
}
|
||||
}
|
||||
@@ -1544,3 +1544,174 @@ peers:
|
||||
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
// twoModels is a minimal models block reused by the routing tests below.
|
||||
const twoModels = `
|
||||
models:
|
||||
gemma:
|
||||
cmd: echo gemma
|
||||
proxy: http://localhost:8080
|
||||
qwen:
|
||||
cmd: echo qwen
|
||||
proxy: http://localhost:8081
|
||||
`
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelGroups(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
// default group injected for orphaned models (none here) still leaves g1
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: matrix
|
||||
settings:
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseGroup(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
settings:
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "migrate")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_RouterUseMatrixWithoutSettings(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: matrix
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "routing.router.settings.matrix is not set")
|
||||
}
|
||||
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active.
|
||||
func TestConfig_Routing_RouterSettingsBothGroupsAndMatrix(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: group
|
||||
settings:
|
||||
groups:
|
||||
g1:
|
||||
members: [gemma, qwen]
|
||||
matrix:
|
||||
sets:
|
||||
s: "gemma"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
// use: group means groups are active and matrix is ignored
|
||||
assert.Equal(t, "group", config.Routing.Router.Use)
|
||||
assert.Nil(t, config.Matrix)
|
||||
assert.Contains(t, config.Groups, "g1")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_UnknownRouter(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
router:
|
||||
use: bogus
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown router")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_FifoPriorityUnknownModel(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
scheduler:
|
||||
settings:
|
||||
fifo:
|
||||
priority:
|
||||
nope: 5
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown model")
|
||||
}
|
||||
|
||||
func TestConfig_Routing_FifoPriorityKnownModel(t *testing.T) {
|
||||
yaml := twoModels + `
|
||||
routing:
|
||||
scheduler:
|
||||
settings:
|
||||
fifo:
|
||||
priority:
|
||||
gemma: 5
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, cfg.Routing.Scheduler.Settings.Fifo.Priority["gemma"])
|
||||
}
|
||||
|
||||
@@ -165,6 +165,25 @@ groups:
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expectedGroups := map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
@@ -235,22 +254,16 @@ groups:
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
Groups: expectedGroups,
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
Settings: RouterSettings{
|
||||
Groups: expectedGroups,
|
||||
},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -15,6 +15,9 @@ type MatrixConfig struct {
|
||||
Var map[string]string `yaml:"vars"`
|
||||
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||
Sets OrderedSets `yaml:"sets"`
|
||||
|
||||
// populated by ValidateMatrix; not settable from yaml
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
}
|
||||
|
||||
// SetEntry is a single named set with its DSL expression.
|
||||
|
||||
@@ -289,7 +289,9 @@ matrix:
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Matrix)
|
||||
assert.Len(t, cfg.ExpandedSets, 2)
|
||||
assert.Len(t, cfg.Matrix.ExpandedSets, 2)
|
||||
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
|
||||
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
|
||||
// Groups should be empty when matrix is used
|
||||
assert.Empty(t, cfg.Groups)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
@@ -9,6 +10,47 @@ const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
var validModalities = map[string]struct{}{
|
||||
"text": {},
|
||||
"audio": {},
|
||||
"image": {},
|
||||
}
|
||||
|
||||
// ModelCapConfig defines what modalities and features a model supports.
|
||||
// Used in /v1/models to inform clients. An empty block (all zero values) is
|
||||
// treated as not configured.
|
||||
type ModelCapConfig struct {
|
||||
In []string `yaml:"in"`
|
||||
Out []string `yaml:"out"`
|
||||
Tools bool `yaml:"tools"`
|
||||
Reranker bool `yaml:"reranker"`
|
||||
Context int `yaml:"context"`
|
||||
}
|
||||
|
||||
// Empty returns true when all fields are at their zero values.
|
||||
func (c ModelCapConfig) Empty() bool {
|
||||
return len(c.In) == 0 && len(c.Out) == 0 && !c.Tools && !c.Reranker && c.Context == 0
|
||||
}
|
||||
|
||||
// Validate checks that all modality values are recognized and context is
|
||||
// non-negative. Returns an error if any value is invalid.
|
||||
func (c ModelCapConfig) Validate() error {
|
||||
for _, m := range c.In {
|
||||
if _, ok := validModalities[m]; !ok {
|
||||
return fmt.Errorf("capabilities.in: invalid modality %q, must be one of: text, audio, image", m)
|
||||
}
|
||||
}
|
||||
for _, m := range c.Out {
|
||||
if _, ok := validModalities[m]; !ok {
|
||||
return fmt.Errorf("capabilities.out: invalid modality %q, must be one of: text, audio, image", m)
|
||||
}
|
||||
}
|
||||
if c.Context < 0 {
|
||||
return errors.New("capabilities.context: must be >= 0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
// 0 = no timeout
|
||||
type TimeoutsConfig struct {
|
||||
@@ -55,6 +97,9 @@ type ModelConfig struct {
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
|
||||
// Capabilities defines what modalities and features the model supports.
|
||||
Capabilities ModelCapConfig `yaml:"capabilities"`
|
||||
|
||||
// Copy of HealthCheckTimeout from global config
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ models:
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -170,3 +170,167 @@ models:
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
func TestConfig_ModelCapabilities(t *testing.T) {
|
||||
t.Run("all fields", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
in:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
out:
|
||||
- text
|
||||
- audio
|
||||
- image
|
||||
tools: true
|
||||
context: 32000
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.In)
|
||||
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.Out)
|
||||
assert.True(t, mc.Capabilities.Tools)
|
||||
assert.Equal(t, 32000, mc.Capabilities.Context)
|
||||
})
|
||||
|
||||
t.Run("partial fields", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
tools: true
|
||||
context: 8192
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.Nil(t, mc.Capabilities.In)
|
||||
assert.Nil(t, mc.Capabilities.Out)
|
||||
assert.True(t, mc.Capabilities.Tools)
|
||||
assert.Equal(t, 8192, mc.Capabilities.Context)
|
||||
})
|
||||
|
||||
t.Run("not set", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
|
||||
t.Run("tools false is empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
tools: false
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
|
||||
t.Run("reranker true is not empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
reranker: true
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.False(t, mc.Capabilities.Empty())
|
||||
assert.True(t, mc.Capabilities.Reranker)
|
||||
})
|
||||
|
||||
t.Run("reranker false is empty", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
reranker: false
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
mc := config.Models["model1"]
|
||||
assert.True(t, mc.Capabilities.Empty())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_ModelCapabilities_Validate(t *testing.T) {
|
||||
t.Run("valid_modalities", func(t *testing.T) {
|
||||
caps := ModelCapConfig{
|
||||
In: []string{"text", "image"},
|
||||
Out: []string{"text", "audio"},
|
||||
Tools: true,
|
||||
Context: 100000,
|
||||
}
|
||||
assert.NoError(t, caps.Validate())
|
||||
})
|
||||
|
||||
t.Run("empty_is_valid", func(t *testing.T) {
|
||||
caps := ModelCapConfig{}
|
||||
assert.NoError(t, caps.Validate())
|
||||
})
|
||||
|
||||
t.Run("invalid_in_modality", func(t *testing.T) {
|
||||
caps := ModelCapConfig{In: []string{"video"}}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.in")
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
|
||||
t.Run("invalid_out_modality", func(t *testing.T) {
|
||||
caps := ModelCapConfig{Out: []string{"video"}}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.out")
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
|
||||
t.Run("negative_context", func(t *testing.T) {
|
||||
caps := ModelCapConfig{Context: -1}
|
||||
err := caps.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "capabilities.context")
|
||||
})
|
||||
|
||||
t.Run("rejects_invalid_at_load", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
capabilities:
|
||||
in:
|
||||
- text
|
||||
- video
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "video")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -43,3 +47,168 @@ func ParseNvidiaSmiLine(line string) *GpuStat {
|
||||
PowerDrawW: powerDraw,
|
||||
}
|
||||
}
|
||||
|
||||
// mactopOutput maps the subset of mactop's headless JSON output that is
|
||||
// relevant to GpuStat. Note that mactop's memory object is whole-system memory,
|
||||
// not GPU-attributed; the darwin monitor overlays ioreg's GPU-attributed
|
||||
// unified memory (see overlayIoregMem) so both backends report consistent
|
||||
// memory figures.
|
||||
type mactopOutput struct {
|
||||
SocMetrics struct {
|
||||
GPUPower float64 `json:"gpu_power"`
|
||||
GPUFreq int `json:"gpu_freq_mhz"`
|
||||
GPUTemp float64 `json:"gpu_temp"`
|
||||
} `json:"soc_metrics"`
|
||||
Memory struct {
|
||||
Total uint64 `json:"total"`
|
||||
Used uint64 `json:"used"`
|
||||
} `json:"memory"`
|
||||
GPUUsage float64 `json:"gpu_usage"`
|
||||
SystemInfo struct {
|
||||
Name string `json:"name"`
|
||||
GPUCoreCount int `json:"gpu_core_count"`
|
||||
} `json:"system_info"`
|
||||
Fans []struct {
|
||||
RPM int `json:"rpm"`
|
||||
MinRPM int `json:"min_rpm"`
|
||||
MaxRPM int `json:"max_rpm"`
|
||||
} `json:"fans"`
|
||||
Temperatures []struct {
|
||||
Group string `json:"group"`
|
||||
Avg float64 `json:"avg_celsius"`
|
||||
} `json:"temperatures"`
|
||||
}
|
||||
|
||||
// ioreg output uses ` = ` (with spaces) for top-level device properties and
|
||||
// `=` (no spaces) for values inside nested dictionaries such as
|
||||
// PerformanceStatistics.
|
||||
var (
|
||||
reIoregModel = regexp.MustCompile(`"model"\s*=\s*"([^"]+)"`)
|
||||
reIoregCoreCount = regexp.MustCompile(`"gpu-core-count"\s*=\s*(\d+)`)
|
||||
reIoregUtil = regexp.MustCompile(`"Device Utilization %"=(\d+)`)
|
||||
reIoregMemUsed = regexp.MustCompile(`"In use system memory"=(\d+)`)
|
||||
)
|
||||
|
||||
// ParseIoregOutput parses `ioreg -r -c IOGPU -d 1 -f` output into a GpuStat for
|
||||
// the Apple Silicon integrated GPU. This is a fallback for when mactop is not
|
||||
// installed: utilization and used memory are available, but power, temperature,
|
||||
// and fan speed are not exposed by ioreg. memTotalMB is the unified memory size
|
||||
// supplied by the caller, since Apple Silicon shares memory between CPU and GPU.
|
||||
// Returns nil if no GPU device is found in the output.
|
||||
func ParseIoregOutput(out []byte, memTotalMB int) *GpuStat {
|
||||
utilMatch := reIoregUtil.FindSubmatch(out)
|
||||
memMatch := reIoregMemUsed.FindSubmatch(out)
|
||||
if utilMatch == nil && memMatch == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var gpuUtil float64
|
||||
if utilMatch != nil {
|
||||
gpuUtil, _ = strconv.ParseFloat(string(utilMatch[1]), 64)
|
||||
}
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
var memUsedMB int
|
||||
if memMatch != nil {
|
||||
memUsedBytes, _ := strconv.ParseInt(string(memMatch[1]), 10, 64)
|
||||
memUsedMB = int(memUsedBytes / toMB)
|
||||
}
|
||||
|
||||
var memUtil float64
|
||||
if memTotalMB > 0 {
|
||||
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||
}
|
||||
|
||||
name := "Apple GPU"
|
||||
if m := reIoregModel.FindSubmatch(out); m != nil {
|
||||
name = string(m[1])
|
||||
}
|
||||
if m := reIoregCoreCount.FindSubmatch(out); m != nil {
|
||||
if cores, err := strconv.Atoi(string(m[1])); err == nil && cores > 0 {
|
||||
name = fmt.Sprintf("%s (%d-core GPU)", name, cores)
|
||||
}
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: 0,
|
||||
Name: name,
|
||||
GpuUtilPct: gpuUtil,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseMactopLine parses a single line of mactop headless JSON output into a
|
||||
// GpuStat for the Apple Silicon integrated GPU. Returns nil if the line cannot
|
||||
// be parsed.
|
||||
func ParseMactopLine(line string) *GpuStat {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var out mactopOutput
|
||||
if err := json.Unmarshal([]byte(line), &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
const toMB = 1024 * 1024
|
||||
memUsedMB := int(out.Memory.Used / toMB)
|
||||
memTotalMB := int(out.Memory.Total / toMB)
|
||||
|
||||
var memUtil float64
|
||||
if memTotalMB > 0 {
|
||||
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||
}
|
||||
|
||||
name := out.SystemInfo.Name
|
||||
if name == "" {
|
||||
name = "Apple GPU"
|
||||
}
|
||||
if out.SystemInfo.GPUCoreCount > 0 {
|
||||
name = fmt.Sprintf("%s (%d-core GPU)", name, out.SystemInfo.GPUCoreCount)
|
||||
}
|
||||
|
||||
// Unified memory has no dedicated VRAM sensor; use the memory temperature
|
||||
// group when mactop exposes it.
|
||||
var vramTempC int
|
||||
for _, t := range out.Temperatures {
|
||||
if strings.EqualFold(t.Group, "Memory") {
|
||||
vramTempC = int(math.Round(t.Avg))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Average fan load across all fans as a percentage of their RPM range.
|
||||
var fanSpeed float64
|
||||
var fanCount int
|
||||
for _, f := range out.Fans {
|
||||
if f.MaxRPM > f.MinRPM {
|
||||
pct := float64(f.RPM-f.MinRPM) / float64(f.MaxRPM-f.MinRPM) * 100
|
||||
if pct < 0 {
|
||||
pct = 0
|
||||
}
|
||||
fanSpeed += pct
|
||||
fanCount++
|
||||
}
|
||||
}
|
||||
if fanCount > 0 {
|
||||
fanSpeed /= float64(fanCount)
|
||||
}
|
||||
|
||||
return &GpuStat{
|
||||
Timestamp: time.Now(),
|
||||
ID: 0,
|
||||
Name: name,
|
||||
TempC: int(math.Round(out.SocMetrics.GPUTemp)),
|
||||
VramTempC: vramTempC,
|
||||
GpuUtilPct: out.GPUUsage,
|
||||
MemUtilPct: memUtil,
|
||||
MemUsedMB: memUsedMB,
|
||||
MemTotalMB: memTotalMB,
|
||||
FanSpeedPct: fanSpeed,
|
||||
PowerDrawW: out.SocMetrics.GPUPower,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("Not Implemented")
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package perf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -11,7 +15,156 @@ import (
|
||||
)
|
||||
|
||||
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
return nil, ErrNotImplemented
|
||||
if ch, err := tryMactop(ctx, every, logger); err == nil {
|
||||
logger.Info("using mactop for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("mactop: %s", err.Error())
|
||||
}
|
||||
|
||||
if ch, err := tryIoreg(ctx, every, logger); err == nil {
|
||||
logger.Info("using ioreg for GPU monitoring")
|
||||
return ch, nil
|
||||
} else {
|
||||
logger.Debugf("ioreg: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// tryIoreg polls `ioreg -r -c IOGPU -d 1 -f` for Apple Silicon GPU stats. It is
|
||||
// a fallback for when mactop is not installed. ioreg exposes GPU utilization and
|
||||
// used memory but not power, temperature, or fan speed.
|
||||
func tryIoreg(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("ioreg"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// Verify ioreg actually reports a GPU device before committing to it, so we
|
||||
// can fall through to ErrNoGpuTool otherwise.
|
||||
if stat := sampleIoreg(ctx); stat == nil {
|
||||
return nil, fmt.Errorf("ioreg reported no GPU device")
|
||||
}
|
||||
|
||||
if every < time.Second {
|
||||
every = time.Second
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ticker := time.NewTicker(every)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
stat := sampleIoreg(ctx)
|
||||
if stat == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// sampleIoreg runs ioreg once and parses a single GpuStat, or returns nil.
|
||||
func sampleIoreg(ctx context.Context) *GpuStat {
|
||||
out, err := exec.CommandContext(ctx, "ioreg", "-r", "-c", "IOGPU", "-d", "1", "-f").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var memTotalMB int
|
||||
if vmStat, err := mem.VirtualMemory(); err == nil {
|
||||
memTotalMB = int(vmStat.Total / (1024 * 1024))
|
||||
}
|
||||
|
||||
return ParseIoregOutput(out, memTotalMB)
|
||||
}
|
||||
|
||||
// overlayIoregMem replaces a GpuStat's memory fields with the GPU-attributed
|
||||
// unified memory reported by ioreg. mactop only exposes whole-system memory, so
|
||||
// without this the mactop and ioreg backends would report different memory
|
||||
// semantics. It is a no-op when ioreg is unavailable or reports no GPU memory,
|
||||
// leaving the mactop-supplied values in place.
|
||||
func overlayIoregMem(ctx context.Context, stat *GpuStat) {
|
||||
ioStat := sampleIoreg(ctx)
|
||||
if ioStat == nil {
|
||||
return
|
||||
}
|
||||
stat.MemUsedMB = ioStat.MemUsedMB
|
||||
stat.MemTotalMB = ioStat.MemTotalMB
|
||||
stat.MemUtilPct = ioStat.MemUtilPct
|
||||
}
|
||||
|
||||
// tryMactop streams Apple Silicon GPU stats from mactop's headless mode.
|
||||
// See https://github.com/metaspartan/mactop. mactop emits one JSON object per
|
||||
// sample to stdout, which we parse into GpuStat.
|
||||
func tryMactop(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||
if _, err := exec.LookPath("mactop"); err != nil {
|
||||
return nil, ErrNoGpuTool
|
||||
}
|
||||
|
||||
// mactop samples power over the interval, so give it at least a second.
|
||||
intervalMs := int(every.Milliseconds())
|
||||
if intervalMs < 1000 {
|
||||
intervalMs = 1000
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "mactop",
|
||||
"--headless",
|
||||
"--format", "json",
|
||||
"--interval", fmt.Sprintf("%d", intervalMs),
|
||||
)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mactop stdout pipe failed: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("mactop start failed: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan []GpuStat, 1)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
// mactop's JSON objects can be large; allow generous line lengths.
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stat := ParseMactopLine(line)
|
||||
if stat != nil {
|
||||
// mactop only reports whole-system memory; overlay ioreg's
|
||||
// GPU-attributed unified memory so both backends are consistent.
|
||||
overlayIoregMem(ctx, stat)
|
||||
select {
|
||||
case ch <- []GpuStat{*stat}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd.Wait()
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func readSysStats() (SysStat, error) {
|
||||
|
||||
@@ -264,3 +264,50 @@ func TestParseNvidiaSmiLine_ZeroMemoryTotal(t *testing.T) {
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||
}
|
||||
|
||||
const ioregSample = `+-o AGXAcceleratorG13X <class AGXAcceleratorG13X, id 0x1000009a1, registered, matched, active, busy 0 (39191 ms), retain 108>
|
||||
{
|
||||
"model" = "Apple M1 Pro"
|
||||
"gpu-core-count" = 16
|
||||
"PerformanceStatistics" = {"In use system memory (driver)"=0,"Alloc system memory"=14511046656,"Tiler Utilization %"=34,"recoveryCount"=0,"Renderer Utilization %"=34,"Device Utilization %"=34,"In use system memory"=7688503296}
|
||||
"IOClass" = "AGXAcceleratorG13X"
|
||||
}`
|
||||
|
||||
func TestParseIoregOutput_ValidOutput(t *testing.T) {
|
||||
const memTotalMB = 32768
|
||||
|
||||
stat := ParseIoregOutput([]byte(ioregSample), memTotalMB)
|
||||
require.NotNil(t, stat)
|
||||
|
||||
assert.Equal(t, 0, stat.ID)
|
||||
assert.Equal(t, "Apple M1 Pro (16-core GPU)", stat.Name)
|
||||
assert.Equal(t, 34.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 7688503296/(1024*1024), stat.MemUsedMB)
|
||||
assert.Equal(t, memTotalMB, stat.MemTotalMB)
|
||||
assert.InDelta(t, float64(stat.MemUsedMB)/memTotalMB*100, stat.MemUtilPct, 0.01)
|
||||
// Not exposed by ioreg.
|
||||
assert.Equal(t, 0, stat.TempC)
|
||||
assert.Equal(t, 0.0, stat.PowerDrawW)
|
||||
assert.Equal(t, 0.0, stat.FanSpeedPct)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_NoGpuDevice(t *testing.T) {
|
||||
stat := ParseIoregOutput([]byte("no gpu here"), 32768)
|
||||
assert.Nil(t, stat)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_ZeroMemTotal(t *testing.T) {
|
||||
stat := ParseIoregOutput([]byte(ioregSample), 0)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||
}
|
||||
|
||||
func TestParseIoregOutput_MissingModel(t *testing.T) {
|
||||
const out = `"Device Utilization %"=50,"In use system memory"=1048576`
|
||||
|
||||
stat := ParseIoregOutput([]byte(out), 1024)
|
||||
require.NotNil(t, stat)
|
||||
assert.Equal(t, "Apple GPU", stat.Name)
|
||||
assert.Equal(t, 50.0, stat.GpuUtilPct)
|
||||
assert.Equal(t, 1, stat.MemUsedMB)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
@@ -22,6 +21,22 @@ import (
|
||||
|
||||
var ErrStartAborted = fmt.Errorf("aborted")
|
||||
|
||||
// cmdWaitDelay is the upper bound the runtime will wait for child I/O to
|
||||
// drain after the process exits before force-closing the stdout/stderr
|
||||
// pipes. Required so that cmd.Wait() returns even when a forked grandchild
|
||||
// inherits and holds the pipes open (e.g. a shell wrapper that backgrounds
|
||||
// the real binary). killProcess sends the stop signal directly (not via the
|
||||
// cmd context), so this delay is measured from process exit rather than from
|
||||
// the stop request, and stays independent of the caller's graceful timeout.
|
||||
const cmdWaitDelay = 10 * time.Second
|
||||
|
||||
// parentCancelGraceTimeout is the graceful timeout used when the process is
|
||||
// torn down because parentCtx was cancelled (final router teardown or app
|
||||
// shutdown). In the normal flow the process has already been stopped via
|
||||
// Stop() by this point, so killProcess is a no-op kill; the short grace just
|
||||
// bounds the rare case where a process is still alive when its context is cut.
|
||||
const parentCancelGraceTimeout = time.Second
|
||||
|
||||
type runReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
@@ -39,6 +54,7 @@ type waitReadyReq struct {
|
||||
type startResult struct {
|
||||
cmd *exec.Cmd
|
||||
cmdDone chan struct{}
|
||||
cancel context.CancelFunc
|
||||
handlerFn http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
@@ -51,6 +67,11 @@ type ProcessCommand struct {
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
// waitDelay is assigned to cmd.WaitDelay when starting the upstream
|
||||
// process. Defaults to cmdWaitDelay; tests override it to keep the
|
||||
// pipe-close backstop from dominating their runtime.
|
||||
waitDelay time.Duration
|
||||
|
||||
runCh chan runReq
|
||||
stopCh chan stopReq
|
||||
waitReadyCh chan waitReadyReq
|
||||
@@ -85,6 +106,7 @@ func New(
|
||||
runCh: make(chan runReq),
|
||||
stopCh: make(chan stopReq),
|
||||
waitReadyCh: make(chan waitReadyReq),
|
||||
waitDelay: cmdWaitDelay,
|
||||
}
|
||||
p.state.Store(StateStopped)
|
||||
|
||||
@@ -122,6 +144,7 @@ func (p *ProcessCommand) run() {
|
||||
var (
|
||||
cmd *exec.Cmd
|
||||
cmdDone <-chan struct{}
|
||||
cmdCancel context.CancelFunc
|
||||
readyWaiters []waitReadyReq
|
||||
// runResp parks the in-flight Run caller's response channel. The
|
||||
// interface contract is that Run blocks until the process is
|
||||
@@ -164,9 +187,10 @@ func (p *ProcessCommand) run() {
|
||||
setState(StateShutdown)
|
||||
if cmd != nil {
|
||||
p.handler.Store(nil)
|
||||
p.killProcess(cmd, cmdDone, 100*time.Millisecond)
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, parentCancelGraceTimeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
@@ -177,8 +201,12 @@ func (p *ProcessCommand) run() {
|
||||
// cmdDone is nil while no process is running, so this case is
|
||||
// dormant outside of StateReady.
|
||||
case <-cmdDone:
|
||||
if cmdCancel != nil {
|
||||
cmdCancel()
|
||||
}
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
p.handler.Store(nil)
|
||||
setState(StateStopped)
|
||||
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
|
||||
@@ -226,6 +254,7 @@ func (p *ProcessCommand) run() {
|
||||
if res.err == nil {
|
||||
cmd = res.cmd
|
||||
cmdDone = res.cmdDone
|
||||
cmdCancel = res.cancel
|
||||
fn := res.handlerFn
|
||||
p.handler.Store(&fn)
|
||||
setState(StateReady)
|
||||
@@ -273,7 +302,7 @@ func (p *ProcessCommand) run() {
|
||||
cancelStart()
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, stop.timeout)
|
||||
p.killProcess(res.cmd, res.cancel, res.cmdDone, stop.timeout)
|
||||
}
|
||||
setState(StateStopped)
|
||||
notifyWaiters(ErrStartAborted)
|
||||
@@ -293,7 +322,7 @@ func (p *ProcessCommand) run() {
|
||||
setState(StateShutdown)
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, 100*time.Millisecond)
|
||||
p.killProcess(res.cmd, res.cancel, res.cmdDone, parentCancelGraceTimeout)
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
@@ -310,9 +339,10 @@ func (p *ProcessCommand) run() {
|
||||
case stop := <-p.stopCh:
|
||||
if cmd != nil {
|
||||
setState(StateStopping)
|
||||
p.killProcess(cmd, cmdDone, stop.timeout)
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, stop.timeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
cmdCancel = nil
|
||||
p.handler.Store(nil)
|
||||
}
|
||||
// Stop is a no-op (and not an error) when already Stopped — this
|
||||
@@ -377,46 +407,71 @@ func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout ti
|
||||
reverseProxy.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
// cmdCtx + cmd.Cancel are wired as a safety net: if the context is ever
|
||||
// cancelled while the process is alive, cmd.Cancel sends SIGTERM / CmdStop
|
||||
// and the runtime escalates to SIGKILL after cmd.WaitDelay. In the normal
|
||||
// teardown path killProcess sends the stop signal directly instead, so
|
||||
// cmd.WaitDelay only acts as the inherited-pipe backstop measured from
|
||||
// process exit (see killProcess).
|
||||
cmdCtx, cmdCancel := context.WithCancel(context.Background())
|
||||
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
|
||||
cmd.Stderr = p.processLogger
|
||||
cmd.Stdout = p.processLogger
|
||||
cmd.Env = append(cmd.Environ(), p.config.Env...)
|
||||
cmd.Cancel = func() error { return p.sendStopSignal(cmd) }
|
||||
cmd.WaitDelay = p.waitDelay
|
||||
setProcAttributes(cmd)
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
if err := cmd.Start(); err != nil {
|
||||
cmdCancel()
|
||||
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
|
||||
}
|
||||
|
||||
go func() {
|
||||
waitErr := cmd.Wait()
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else if waitErr != nil {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
} else {
|
||||
switch st := p.State(); {
|
||||
case waitErr == nil:
|
||||
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
|
||||
case st == StateStopping || st == StateShutdown:
|
||||
// Expected: we force-terminated the process. A forced kill exits
|
||||
// the child with a non-zero code (e.g. taskkill /f on Windows
|
||||
// yields exit status 1), so this is not an error.
|
||||
p.proxyLogger.Debugf("<%s> process stopped by llama-swap: %v", p.id, waitErr)
|
||||
default:
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
}
|
||||
}
|
||||
close(cmdDone)
|
||||
}()
|
||||
|
||||
abort := func(err error) startResult {
|
||||
p.killProcess(cmd, cmdCancel, cmdDone, 5*time.Second)
|
||||
return startResult{err: err}
|
||||
}
|
||||
prematureExit := func() startResult {
|
||||
cmdCancel()
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
}
|
||||
|
||||
if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
if checkEndpoint == "none" {
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
// Wait 250ms for the command to start up before health checking
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
}
|
||||
|
||||
@@ -424,16 +479,14 @@ func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout ti
|
||||
for {
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
return prematureExit()
|
||||
default:
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: fmt.Errorf("health check timed out after %v", healthCheckTimeout)}
|
||||
return abort(fmt.Errorf("health check timed out after %v", healthCheckTimeout))
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
|
||||
@@ -445,42 +498,99 @@ func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout ti
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
|
||||
break
|
||||
} else if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
return abort(ErrStartAborted)
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
return prematureExit()
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
// sendStopSignal runs the configured CmdStop (if any) or sends SIGTERM to
|
||||
// the upstream process. Wired up as cmd.Cancel so it fires whenever the
|
||||
// cmd's context is cancelled.
|
||||
func (p *ProcessCommand) sendStopSignal(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() called with nil cmd or process, nothing to stop", p.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
if p.config.CmdStop != "" {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() using CmdStop %q for pid %d", p.id, p.config.CmdStop, pid)
|
||||
stopArgs, err := config.SanitizeCommand(
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", cmd.Process.Pid)),
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", pid)),
|
||||
)
|
||||
if err == nil {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() running stop command: %s", p.id, strings.Join(stopArgs, " "))
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Env = cmd.Env
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Run()
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
runErr := stopCmd.Run()
|
||||
if runErr != nil {
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() stop command failed: %v", p.id, runErr)
|
||||
} else {
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() stop command completed for pid %d", p.id, pid)
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
// fall through to SIGTERM if sanitize failed
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() failed to sanitize CmdStop %q: %v, falling back to terminateProcessTree", p.id, p.config.CmdStop, err)
|
||||
}
|
||||
// On Unix this SIGTERMs the whole process group so a forked grandchild
|
||||
// (e.g. a shell wrapper that backgrounds the real binary) is taken down
|
||||
// with the parent rather than orphaned.
|
||||
p.processLogger.Debugf("<%s> sendStopSignal() no CmdStop configured, calling terminateProcessTree for pid %d", p.id, pid)
|
||||
termErr := terminateProcessTree(cmd)
|
||||
if termErr != nil {
|
||||
p.processLogger.Errorf("<%s> sendStopSignal() terminateProcessTree failed for pid %d: %v", p.id, pid, termErr)
|
||||
}
|
||||
return termErr
|
||||
}
|
||||
|
||||
// killProcess terminates the upstream process. The flow:
|
||||
//
|
||||
// 1. Send the graceful stop signal (CmdStop / SIGTERM) directly — NOT by
|
||||
// cancelling cmdCtx. Cancelling the context would start cmd.WaitDelay
|
||||
// immediately, which force-kills the process WaitDelay after the signal
|
||||
// and would silently cap gracefulTimeout at WaitDelay whenever
|
||||
// gracefulTimeout is the longer of the two.
|
||||
// 2. We wait up to gracefulTimeout for the process to exit on its own.
|
||||
// 3. If still alive, we SIGKILL the process group directly (Unix) so any
|
||||
// forked descendant is force-terminated alongside the parent.
|
||||
// 4. We wait on cmdDone. cmd.WaitDelay (set when the cmd was built) is the
|
||||
// critical backstop here: once the process exits, if a forked grandchild
|
||||
// inherited the stdout/stderr pipes and is still holding them, the runtime
|
||||
// force-closes the pipes WaitDelay after the exit and cmd.Wait() unblocks.
|
||||
// Because we never cancelled the context, that WaitDelay timer measures
|
||||
// from process exit (see os/exec awaitGoroutines), not from this call.
|
||||
// Without WaitDelay this select would hang forever (the v219 bug).
|
||||
//
|
||||
// cancel() is still invoked (deferred) to release the context, but only after
|
||||
// the process has exited and os/exec's ctx watcher has already torn down, so it
|
||||
// never re-fires cmd.Cancel.
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cancel context.CancelFunc, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
// Deliver CmdStop / SIGTERM in a goroutine so a slow or hanging CmdStop
|
||||
// cannot block the run() goroutine; the gracefulTimeout + Process.Kill
|
||||
// path below still guarantees teardown.
|
||||
if cmd != nil {
|
||||
go func() {
|
||||
p.proxyLogger.Debugf("[%s] sending stop signal with timeout %v", p.id, gracefulTimeout)
|
||||
if err := p.sendStopSignal(cmd); err != nil {
|
||||
p.proxyLogger.Warnf("[%s] stop signal failed: %v", p.id, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(gracefulTimeout)
|
||||
@@ -488,10 +598,16 @@ func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gra
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
return
|
||||
case <-timer.C:
|
||||
cmd.Process.Kill()
|
||||
<-cmdDone
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
// SIGKILL the whole process group on Unix so any descendant that
|
||||
// ignored or outlived the graceful signal is force-terminated too.
|
||||
_ = killProcessTree(cmd)
|
||||
}
|
||||
<-cmdDone
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ID() string {
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
// TestProcessCommand_StopForkingWrapper is a regression for the bug reported
|
||||
// against v219 where Stop would hang indefinitely when the upstream command
|
||||
// is a shell wrapper that forks the real binary (e.g. `#!/bin/bash` then
|
||||
// `"$@"`). After SIGTERM the wrapper dies but the grandchild inherits the
|
||||
// stdout/stderr pipes; cmd.Wait() blocks waiting for the pipe-copy goroutine
|
||||
// to drain EOF, which never happens while the grandchild holds the fds.
|
||||
//
|
||||
// The fix is cmd.WaitDelay (combined with exec.CommandContext + cmd.Cancel),
|
||||
// which causes the runtime to force-close the pipes after the delay so
|
||||
// cmd.Wait() — and therefore Stop — returns.
|
||||
func TestProcessCommand_StopForkingWrapper(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
port := getFreePort(t)
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "child.pid")
|
||||
|
||||
// Wrapper script: backgrounds the child (which inherits stdout/stderr),
|
||||
// records its PID for cleanup, then waits. When SIGTERM hits bash it
|
||||
// dies without forwarding the signal; the grandchild keeps running and
|
||||
// keeps the inherited pipe fds open. This is the scenario reported in
|
||||
// the v219 regression.
|
||||
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||
simpleResponderPath, port, pidFile)
|
||||
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: wrapper,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
// Shrink the pipe-close backstop so the test doesn't sit at the
|
||||
// production default (10s). Must be set before Run() so doStart picks
|
||||
// it up when building the cmd.
|
||||
const testWaitDelay = 250 * time.Millisecond
|
||||
p.waitDelay = testWaitDelay
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Stop must return within a bounded time even though the grandchild
|
||||
// is still holding the pipe open. Budget is generous on top of
|
||||
// testWaitDelay to absorb scheduling jitter on slow CI runners; the
|
||||
// pre-fix behaviour was an unbounded hang, so any reasonable cap
|
||||
// distinguishes pass from fail.
|
||||
stopReturned := make(chan error, 1)
|
||||
stopStart := time.Now()
|
||||
go func() { stopReturned <- p.Stop(testStopTimeout) }()
|
||||
|
||||
const stopBudget = testWaitDelay + 2*time.Second
|
||||
select {
|
||||
case err := <-stopReturned:
|
||||
if err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
t.Logf("Stop returned in %v", time.Since(stopStart))
|
||||
case <-time.After(stopBudget):
|
||||
t.Fatalf("Stop did not return within %v — cmd.Wait() likely hung on inherited pipe", stopBudget)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopHonorsGracefulTimeout is a regression for the bug
|
||||
// where cmd.WaitDelay capped the graceful shutdown window. killProcess used to
|
||||
// cancel the cmd context to deliver SIGTERM, which starts cmd.WaitDelay
|
||||
// immediately; a process whose SIGTERM handler needs longer than WaitDelay to
|
||||
// finish was force-killed early even though Stop was given a much longer
|
||||
// timeout. The fix sends the signal directly so WaitDelay measures from process
|
||||
// exit (its inherited-pipe backstop role), leaving the graceful window to the
|
||||
// caller's Stop timeout.
|
||||
func TestProcessCommand_StopHonorsGracefulTimeout(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
marker := filepath.Join(dir, "graceful.done")
|
||||
ready := filepath.Join(dir, "trap.ready")
|
||||
|
||||
// On SIGTERM, sleep past the (short) WaitDelay, then write the marker and
|
||||
// exit cleanly. If WaitDelay still drove the kill, bash would be SIGKILLed
|
||||
// mid-handler and the marker would never be written. The ready file is
|
||||
// written only after the trap is installed so the test does not race
|
||||
// SIGTERM ahead of it (CheckEndpoint:none marks ready before bash runs).
|
||||
script := filepath.Join(dir, "graceful.sh")
|
||||
body := fmt.Sprintf(
|
||||
"#!/bin/bash\ncleanup() { sleep 0.6; echo done > %q; exit 0; }\ntrap cleanup SIGTERM\necho ready > %q\nwhile true; do sleep 0.1; done\n",
|
||||
marker, ready,
|
||||
)
|
||||
if err := os.WriteFile(script, []byte(body), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: script,
|
||||
Proxy: "http://127.0.0.1:1", // unused: health check disabled
|
||||
CheckEndpoint: "none",
|
||||
})
|
||||
// WaitDelay shorter than the handler's 0.6s sleep, and far shorter than the
|
||||
// Stop timeout below — this is the window the old code mis-killed in.
|
||||
p.waitDelay = 200 * time.Millisecond
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Wait until the trap is installed before stopping.
|
||||
trapDeadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if _, err := os.Stat(ready); err == nil {
|
||||
break
|
||||
}
|
||||
if time.Now().After(trapDeadline) {
|
||||
t.Fatalf("script did not install SIGTERM trap in time")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
stopStart := time.Now()
|
||||
if err := p.Stop(5 * time.Second); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
elapsed := time.Since(stopStart)
|
||||
|
||||
// The handler must have run to completion (marker written) rather than
|
||||
// being force-killed at waitDelay.
|
||||
if _, err := os.Stat(marker); err != nil {
|
||||
t.Fatalf("graceful handler did not complete (marker missing): %v", err)
|
||||
}
|
||||
// And Stop must have waited for the handler (>~0.6s), not returned at the
|
||||
// 200ms waitDelay.
|
||||
if elapsed < 500*time.Millisecond {
|
||||
t.Fatalf("Stop returned in %v — process was killed before its graceful handler finished", elapsed)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopReapsForkedGrandchild verifies that stopping a forking
|
||||
// wrapper takes down the backgrounded grandchild too, rather than leaving it as
|
||||
// an orphan. The fix is Setpgid (runtime_unix.go): the wrapper leads its own
|
||||
// process group, so the stop signal is delivered to the whole group via the
|
||||
// negative PID and reaches the grandchild the wrapper never reaped.
|
||||
func TestProcessCommand_StopReapsForkedGrandchild(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
port := getFreePort(t)
|
||||
dir := t.TempDir()
|
||||
pidFile := filepath.Join(dir, "child.pid")
|
||||
|
||||
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||
simpleResponderPath, port, pidFile)
|
||||
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: wrapper,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
// Read the grandchild PID the wrapper recorded.
|
||||
var childPID int
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
data, err := os.ReadFile(pidFile)
|
||||
if err == nil {
|
||||
if pid, perr := strconv.Atoi(strings.TrimSpace(string(data))); perr == nil && pid > 0 {
|
||||
childPID = pid
|
||||
break
|
||||
}
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("wrapper did not record grandchild PID")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
|
||||
// After Stop the grandchild must be gone. Signal 0 probes liveness without
|
||||
// actually sending a signal; give it a brief window to exit after the
|
||||
// group SIGTERM.
|
||||
proc, err := os.FindProcess(childPID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess: %v", err)
|
||||
}
|
||||
gone := false
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := proc.Signal(syscall.Signal(0)); err != nil {
|
||||
gone = true
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if !gone {
|
||||
t.Errorf("grandchild PID %d still alive after Stop — process group was not reaped", childPID)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Errorf("Run did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// killChildFromPidFile reads a PID written by the wrapper script and SIGKILLs
|
||||
// it so leaked orphans don't accumulate between test runs. Best-effort.
|
||||
func killChildFromPidFile(pidFile string) {
|
||||
data, err := os.ReadFile(pidFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
if err != nil || pid <= 0 {
|
||||
return
|
||||
}
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = proc.Kill()
|
||||
}
|
||||
@@ -4,9 +4,41 @@ package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
// setProcAttributes starts the upstream in its own process group (Setpgid) so
|
||||
// the entire process tree can be signalled at once via its negative PID. This
|
||||
// is what lets us reap a forked grandchild — e.g. a shell wrapper that
|
||||
// backgrounds the real binary and exits — instead of leaking it as an orphan
|
||||
// that holds the inherited stdout/stderr pipes open.
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
// terminateProcessTree sends SIGTERM to the whole process group led by the
|
||||
// command, giving every process in the tree a chance to shut down gracefully.
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
return signalProcessTree(cmd, syscall.SIGTERM)
|
||||
}
|
||||
|
||||
// killProcessTree sends SIGKILL to the whole process group, force-terminating
|
||||
// every process in the tree.
|
||||
func killProcessTree(cmd *exec.Cmd) error {
|
||||
return signalProcessTree(cmd, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
// signalProcessTree signals the process group led by cmd.Process. Because the
|
||||
// child was started with Setpgid it is its own group leader (pgid == pid), so
|
||||
// targeting -pid reaches the child and every descendant still in the group.
|
||||
// Falls back to signalling just the child if the group send fails (e.g. the
|
||||
// group has already drained), so we never silently skip the signal.
|
||||
func signalProcessTree(cmd *exec.Cmd, sig syscall.Signal) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
if err := syscall.Kill(-cmd.Process.Pid, sig); err != nil {
|
||||
return cmd.Process.Signal(sig)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,14 +3,51 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
// setProcAttributes sets platform-specific process attributes. CREATE_NO_WINDOW
|
||||
// keeps the upstream from spawning its own console window.
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
|
||||
// terminateProcessTree requests a graceful shutdown of the whole process tree
|
||||
// rooted at cmd.Process. Windows has no SIGTERM or process-group signalling, so
|
||||
// we shell out to `taskkill /t`, which walks the child tree by PID — the
|
||||
// equivalent of signalling a Unix process group. Without /f, taskkill asks the
|
||||
// processes to close rather than force-killing them.
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
return taskkillProcessTree(cmd, false)
|
||||
}
|
||||
|
||||
// killProcessTree force-terminates the whole process tree rooted at cmd.Process
|
||||
// via `taskkill /f /t`, so any descendant that ignored or outlived the graceful
|
||||
// request is killed alongside the parent rather than leaked as an orphan.
|
||||
func killProcessTree(cmd *exec.Cmd) error {
|
||||
return taskkillProcessTree(cmd, true)
|
||||
}
|
||||
|
||||
// taskkillProcessTree runs taskkill against cmd.Process.Pid. The /t flag
|
||||
// terminates the process together with any child processes it started, which is
|
||||
// the Windows analogue of signalling a Unix process group via its negative PID.
|
||||
// When force is true the /f flag force-kills; otherwise taskkill requests a
|
||||
// graceful close.
|
||||
func taskkillProcessTree(cmd *exec.Cmd, force bool) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
args := make([]string, 0, 4)
|
||||
if force {
|
||||
args = append(args, "/f")
|
||||
}
|
||||
args = append(args, "/t", "/pid", fmt.Sprintf("%d", cmd.Process.Pid))
|
||||
kill := exec.Command("taskkill", args...)
|
||||
setProcAttributes(kill)
|
||||
return kill.Run()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
// SetupTreeCleanup is a no-op on non-Windows platforms, where upstream process
|
||||
// teardown is handled via process-group signalling (see runtime_unix.go).
|
||||
func SetupTreeCleanup() error { return nil }
|
||||
@@ -0,0 +1,50 @@
|
||||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// SetupTreeCleanup assigns the current process to a Windows Job Object
|
||||
// configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. Upstream processes
|
||||
// spawned afterwards are associated with the same job, so when llama-swap exits
|
||||
// for any reason — graceful shutdown, a forced second Ctrl+C, or a crash — the
|
||||
// OS terminates the whole job and reaps every child instead of leaving orphans
|
||||
// behind. It is the parent-side complement to the per-process teardown in
|
||||
// runtime_windows.go.
|
||||
//
|
||||
// The job handle is intentionally leaked for the lifetime of the process: the
|
||||
// kill-on-close behaviour fires when the last handle is released, which the OS
|
||||
// does when the process exits.
|
||||
func SetupTreeCleanup() error {
|
||||
job, err := windows.CreateJobObject(nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateJobObject: %w", err)
|
||||
}
|
||||
|
||||
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
|
||||
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
|
||||
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||
},
|
||||
}
|
||||
if _, err := windows.SetInformationJobObject(
|
||||
job,
|
||||
windows.JobObjectExtendedLimitInformation,
|
||||
uintptr(unsafe.Pointer(&info)),
|
||||
uint32(unsafe.Sizeof(info)),
|
||||
); err != nil {
|
||||
windows.CloseHandle(job)
|
||||
return fmt.Errorf("SetInformationJobObject: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.AssignProcessToJobObject(job, windows.CurrentProcess()); err != nil {
|
||||
windows.CloseHandle(job)
|
||||
return fmt.Errorf("AssignProcessToJobObject: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+152
-425
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type shutdownReq struct {
|
||||
@@ -24,66 +26,39 @@ type unloadReq struct {
|
||||
respond chan struct{}
|
||||
}
|
||||
|
||||
type handlerReq struct {
|
||||
model string
|
||||
ctx context.Context
|
||||
respond chan handlerResp
|
||||
positionCh chan int
|
||||
}
|
||||
|
||||
type handlerResp struct {
|
||||
handleFunc http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type swapDone struct {
|
||||
modelID string
|
||||
err error
|
||||
}
|
||||
|
||||
type serveDoneEvent struct {
|
||||
modelID string
|
||||
}
|
||||
|
||||
type activeSwap struct {
|
||||
modelID string
|
||||
evict []string
|
||||
waiters []handlerReq
|
||||
}
|
||||
|
||||
// swapPlanner is the only piece of behaviour that differs between concrete
|
||||
// routers. baseRouter never inspects its internals.
|
||||
type swapPlanner interface {
|
||||
// EvictionFor returns running model IDs that must be stopped before
|
||||
// target can serve. alsoRunning lists models the baseRouter has already
|
||||
// committed to loading (in-flight swaps) which the planner cannot see
|
||||
// via process.State() yet. Pure decision; must not log.
|
||||
EvictionFor(target string, alsoRunning []string) []string
|
||||
|
||||
// OnSwapStart runs once at the start of every swap. Planners may log
|
||||
// their decision here at whatever verbosity they choose.
|
||||
OnSwapStart(target string)
|
||||
}
|
||||
|
||||
// baseRouter owns the channels, run-loop, and orchestration code shared by
|
||||
// every concrete router. Concrete routers embed *baseRouter and supply a
|
||||
// swapPlanner that captures how their eviction set is decided.
|
||||
// baseRouter owns the channels, run-loop, and process machinery shared by every
|
||||
// concrete router. Concrete routers embed *baseRouter and supply a
|
||||
// scheduler.Factory (which captures their scheduler.Swapper) describing how
|
||||
// requests are scheduled and how their eviction set is decided. baseRouter
|
||||
// implements scheduler.Effects so the scheduler can call back for side-effects.
|
||||
type baseRouter struct {
|
||||
name string
|
||||
config config.Config
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
planner swapPlanner
|
||||
schedule scheduler.Scheduler
|
||||
|
||||
// shutdownCtx governs the request machinery: cancelling it tells grant()
|
||||
// and ServeHTTP to stop granting and reject callers. It is deliberately
|
||||
// separate from procCtx — see procCtx below.
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
|
||||
handlerCh chan handlerReq
|
||||
// procCtx is the parent context for every managed process and governs
|
||||
// process lifetime only. handleShutdown stops processes gracefully via
|
||||
// Stop() and cancels procCtx afterwards, so teardown is never a context
|
||||
// cancel racing the graceful path (which collapsed the grace to 100ms and
|
||||
// let the caller return before children were reaped — see process run loop).
|
||||
procCtx context.Context
|
||||
procCancel context.CancelFunc
|
||||
|
||||
handlerCh chan scheduler.HandlerReq
|
||||
cancelCh chan scheduler.HandlerReq
|
||||
shutdownCh chan shutdownReq
|
||||
unloadCh chan unloadReq
|
||||
swapDoneCh chan swapDone
|
||||
serveDoneCh chan serveDoneEvent
|
||||
swapDoneCh chan scheduler.SwapDone
|
||||
serveDoneCh chan scheduler.ServeDoneEvent
|
||||
|
||||
runDone chan struct{}
|
||||
|
||||
@@ -95,23 +70,34 @@ type baseRouter struct {
|
||||
testProcessed chan struct{}
|
||||
}
|
||||
|
||||
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
|
||||
func newBaseRouter(
|
||||
name string,
|
||||
conf config.Config,
|
||||
processes map[string]process.Process,
|
||||
logger *logmon.Monitor,
|
||||
newSched scheduler.Factory,
|
||||
) *baseRouter {
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
return &baseRouter{
|
||||
procCtx, procCancel := context.WithCancel(context.Background())
|
||||
b := &baseRouter{
|
||||
name: name,
|
||||
config: conf,
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
planner: planner,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
handlerCh: make(chan handlerReq),
|
||||
procCtx: procCtx,
|
||||
procCancel: procCancel,
|
||||
handlerCh: make(chan scheduler.HandlerReq),
|
||||
cancelCh: make(chan scheduler.HandlerReq),
|
||||
shutdownCh: make(chan shutdownReq),
|
||||
unloadCh: make(chan unloadReq),
|
||||
swapDoneCh: make(chan swapDone),
|
||||
serveDoneCh: make(chan serveDoneEvent),
|
||||
swapDoneCh: make(chan scheduler.SwapDone),
|
||||
serveDoneCh: make(chan scheduler.ServeDoneEvent),
|
||||
runDone: make(chan struct{}),
|
||||
}
|
||||
b.schedule = newSched(name, logger, b)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *baseRouter) notifyProcessed() {
|
||||
@@ -123,30 +109,31 @@ func (b *baseRouter) notifyProcessed() {
|
||||
func (b *baseRouter) run() {
|
||||
defer close(b.runDone)
|
||||
|
||||
active := make(map[string]*activeSwap)
|
||||
inFlight := make(map[string]int)
|
||||
var queued []handlerReq
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh:
|
||||
b.handleShutdown(req, active, queued)
|
||||
b.handleShutdown(req)
|
||||
return
|
||||
|
||||
case req := <-b.handlerCh:
|
||||
b.handleRequest(req, active, inFlight, &queued)
|
||||
b.schedule.OnRequest(req)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.cancelCh:
|
||||
b.schedule.OnCancel(req)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.unloadCh:
|
||||
b.handleUnload(req, active, inFlight, &queued)
|
||||
b.schedule.OnUnload(req.targets, req.timeout)
|
||||
close(req.respond)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.swapDoneCh:
|
||||
b.handleSwapDone(ev, active, inFlight, &queued)
|
||||
b.schedule.OnSwapDone(ev)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.serveDoneCh:
|
||||
b.handleServeDone(ev, active, inFlight, &queued)
|
||||
b.schedule.OnServeDone(ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,37 +150,68 @@ func (b *baseRouter) run() {
|
||||
// down, the send never lands, one of the other select cases fires, and we
|
||||
// report back that the grant did NOT happen.
|
||||
//
|
||||
// That distinction matters for in-flight bookkeeping — see grantHandler.
|
||||
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
|
||||
// That distinction matters for in-flight bookkeeping — see GrantServe.
|
||||
func (b *baseRouter) grant(req scheduler.HandlerReq, resp scheduler.HandlerResp) bool {
|
||||
select {
|
||||
case req.respond <- resp:
|
||||
case req.Respond <- resp:
|
||||
return true
|
||||
case <-req.ctx.Done():
|
||||
case <-req.Ctx.Done():
|
||||
return false
|
||||
case <-b.shutdownCtx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// grantHandler is the "this caller can now use process p" path. It does
|
||||
// two things that must stay locked together:
|
||||
//
|
||||
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
|
||||
// HTTP request finishes, the run loop hears about it.
|
||||
// 2. Bump inFlight[modelID] so the router knows this process is busy and
|
||||
// refuses to evict it until the count comes back down.
|
||||
//
|
||||
// The increment is gated on grant() returning true. If grant() returns
|
||||
// false, the caller already walked away and trackedServe will never run —
|
||||
// which means no matching decrement will ever arrive on serveDoneCh.
|
||||
// Incrementing in that case would strand the counter at >0 forever and
|
||||
// the router would never again be willing to swap this model out.
|
||||
//
|
||||
// In short: increment if and only if we know a decrement is coming.
|
||||
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
|
||||
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
|
||||
inFlight[modelID]++
|
||||
// ModelState implements scheduler.Effects.
|
||||
func (b *baseRouter) ModelState(modelID string) (process.ProcessState, bool) {
|
||||
p, ok := b.processes[modelID]
|
||||
if !ok {
|
||||
var zero process.ProcessState
|
||||
return zero, false
|
||||
}
|
||||
return p.State(), true
|
||||
}
|
||||
|
||||
// StartSwap implements scheduler.Effects, launching the swap goroutine.
|
||||
func (b *baseRouter) StartSwap(modelID string, evict []string) {
|
||||
go b.doSwap(modelID, evict)
|
||||
}
|
||||
|
||||
// GrantError implements scheduler.Effects.
|
||||
func (b *baseRouter) GrantError(req scheduler.HandlerReq, err error) {
|
||||
b.grant(req, scheduler.HandlerResp{Err: err})
|
||||
}
|
||||
|
||||
// GrantServe implements scheduler.Effects. It hands the caller a wrapped
|
||||
// p.ServeHTTP (via trackedServe) so the run loop hears about the request
|
||||
// finishing, and reports whether the caller received it. The scheduler bumps
|
||||
// its in-flight count only on a true return: if grant() returns false the
|
||||
// caller already walked away and trackedServe will never run, so no matching
|
||||
// decrement will ever arrive — incrementing would strand the counter at >0 and
|
||||
// the router would never again be willing to evict this model.
|
||||
func (b *baseRouter) GrantServe(req scheduler.HandlerReq, modelID string) bool {
|
||||
p := b.processes[modelID]
|
||||
return b.grant(req, scheduler.HandlerResp{HandleFunc: b.trackedServe(modelID, p)})
|
||||
}
|
||||
|
||||
// StopProcesses implements scheduler.Effects, stopping the named processes in
|
||||
// parallel and blocking until all have stopped.
|
||||
func (b *baseRouter) StopProcesses(timeout time.Duration, ids []string) {
|
||||
var wg sync.WaitGroup
|
||||
for _, id := range ids {
|
||||
p, ok := b.processes[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(timeout); err != nil {
|
||||
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(id, p)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
||||
@@ -210,7 +228,7 @@ func (b *baseRouter) trackedServe(modelID string, p process.Process) http.Handle
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
select {
|
||||
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
|
||||
case b.serveDoneCh <- scheduler.ServeDoneEvent{ModelID: modelID}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
@@ -218,240 +236,6 @@ func (b *baseRouter) trackedServe(modelID string, p process.Process) http.Handle
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest decides what to do with one incoming ServeHTTP request. It is
|
||||
// called from run() and never blocks indefinitely: any work that has to wait
|
||||
// (starting a process, stopping siblings, waiting for ready) is deferred to
|
||||
// a swap goroutine and reported back via swapDoneCh.
|
||||
//
|
||||
// The decision tree, in order:
|
||||
//
|
||||
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
|
||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||
// one swap serves all callers that asked for the same model.
|
||||
// 3. Fast path — the target process is already ready, the planner sees
|
||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
|
||||
// 4. Would collide with an in-flight swap (we'd stop their target, or
|
||||
// they're stopping us) — park in the queue for handleSwapDone to drain.
|
||||
// 5. Would evict a process that is still handling requests — park in the
|
||||
// queue. handleServeDone will retry when the busy process drains.
|
||||
// 6. Otherwise — start a new swap. This may run in parallel with other
|
||||
// active swaps when their evict sets don't intersect.
|
||||
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
// (1) Unknown model.
|
||||
p, ok := b.processes[req.model]
|
||||
if !ok {
|
||||
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
|
||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||
return
|
||||
}
|
||||
|
||||
// (2) Join an in-flight swap for the same model.
|
||||
if s, ok := active[req.model]; ok {
|
||||
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
|
||||
s.waiters = append(s.waiters, req)
|
||||
return
|
||||
}
|
||||
|
||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||
|
||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
|
||||
b.grantHandler(req, req.model, p, inFlight)
|
||||
return
|
||||
}
|
||||
|
||||
// (4) Collision with an in-flight swap — queue.
|
||||
if collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
|
||||
*queued = append(*queued, req)
|
||||
b.broadcastQueuePositions(*queued)
|
||||
return
|
||||
}
|
||||
|
||||
// (5) Would evict a busy process — queue until it drains.
|
||||
if conflictsWithInFlight(evict, inFlight) {
|
||||
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
|
||||
*queued = append(*queued, req)
|
||||
b.broadcastQueuePositions(*queued)
|
||||
return
|
||||
}
|
||||
|
||||
// (6) Start a new (possibly parallel) swap.
|
||||
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
|
||||
s := b.startSwap(req, evict)
|
||||
active[s.modelID] = s
|
||||
}
|
||||
|
||||
// handleSwapDone is called from run() when a swap goroutine reports that it
|
||||
// has finished. It fans out the result to every waiter that joined this swap,
|
||||
// removes the swap from the active map, and then walks the queue once,
|
||||
// promoting any items that no longer collide with the remaining active set.
|
||||
// FIFO order is preserved: items still blocked stay in place.
|
||||
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
s, ok := active[ev.modelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(active, ev.modelID)
|
||||
|
||||
for _, w := range s.waiters {
|
||||
if ev.err != nil {
|
||||
b.grant(w, handlerResp{err: ev.err})
|
||||
} else {
|
||||
p := b.processes[ev.modelID]
|
||||
b.grantHandler(w, ev.modelID, p, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
}
|
||||
|
||||
// handleServeDone is called from run() each time a tracked ServeHTTP
|
||||
// finishes. It decrements the per-model in-flight count and, when that
|
||||
// drops to zero, retries the queue: requests whose swap was deferred
|
||||
// because they would have evicted this (now-idle) process can now proceed.
|
||||
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
inFlight[ev.modelID]--
|
||||
if inFlight[ev.modelID] <= 0 {
|
||||
delete(inFlight, ev.modelID)
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
}
|
||||
}
|
||||
|
||||
// drainQueue walks the queued requests in order, re-running the handleRequest
|
||||
// decision tree against the (now smaller) active set. Items that can now start
|
||||
// or join become satisfied; items still blocked remain queued in original
|
||||
// order so they get another chance on the next swap completion.
|
||||
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
if len(*queued) == 0 {
|
||||
return
|
||||
}
|
||||
pending := *queued
|
||||
var remaining []handlerReq
|
||||
for _, req := range pending {
|
||||
p, ok := b.processes[req.model]
|
||||
if !ok {
|
||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||
continue
|
||||
}
|
||||
if s, ok := active[req.model]; ok {
|
||||
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
|
||||
s.waiters = append(s.waiters, req)
|
||||
continue
|
||||
}
|
||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
|
||||
b.grantHandler(req, req.model, p, inFlight)
|
||||
continue
|
||||
}
|
||||
if collidesWith(req.model, evict, active) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
if conflictsWithInFlight(evict, inFlight) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
|
||||
s := b.startSwap(req, evict)
|
||||
active[s.modelID] = s
|
||||
}
|
||||
*queued = remaining
|
||||
b.broadcastQueuePositions(*queued)
|
||||
}
|
||||
|
||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||
// drained first so the consumer always sees the latest position.
|
||||
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
|
||||
for i, req := range queued {
|
||||
pos := i + 1
|
||||
select {
|
||||
case req.positionCh <- pos:
|
||||
default:
|
||||
select {
|
||||
case <-req.positionCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case req.positionCh <- pos:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
|
||||
swap := &activeSwap{
|
||||
modelID: initial.model,
|
||||
evict: evict,
|
||||
waiters: []handlerReq{initial},
|
||||
}
|
||||
b.planner.OnSwapStart(initial.model)
|
||||
go b.doSwap(initial.model, evict)
|
||||
return swap
|
||||
}
|
||||
|
||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||
// baseRouter passes this to the planner so eviction decisions account for
|
||||
// models that have been committed to but have not yet transitioned to
|
||||
// StateStarting in their process state machine.
|
||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||
if len(active) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(active))
|
||||
for id := range active {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// collidesWith reports whether a new swap with this target and evict set can
|
||||
// safely run alongside the currently active swaps. Same-target callers should
|
||||
// JOIN (handled before this) — they do not collide with themselves.
|
||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||
for id, s := range active {
|
||||
if id == target {
|
||||
continue
|
||||
}
|
||||
if containsString(evict, id) {
|
||||
return true
|
||||
}
|
||||
if containsString(s.evict, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
||||
// requests. Stopping a busy process would cancel its callers' connections,
|
||||
// so the router defers the swap until those callers finish.
|
||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
||||
for _, m := range evict {
|
||||
if inFlight[m] > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(xs []string, s string) bool {
|
||||
for _, x := range xs {
|
||||
if x == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||
timeout := b.healthCheckTimeout()
|
||||
|
||||
@@ -479,29 +263,24 @@ func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||
err := target.WaitReady(b.shutdownCtx)
|
||||
|
||||
select {
|
||||
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
|
||||
case b.swapDoneCh <- scheduler.SwapDone{ModelID: modelID, Err: err}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
|
||||
func (b *baseRouter) handleShutdown(req shutdownReq) {
|
||||
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
||||
|
||||
// Cancel shutdownCtx first so any waiter that is currently parked on
|
||||
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
||||
// The grant calls below then either land (waiter happened to receive
|
||||
// The OnShutdown grants below then either land (waiter happened to receive
|
||||
// before noticing shutdown) or fall through immediately via grant's
|
||||
// shutdownCtx case — either way the waiter sees a non-OK response.
|
||||
// This does NOT touch processes: their lifetime is procCtx, cancelled
|
||||
// only after the graceful Stop() calls below have reaped them.
|
||||
b.shutdownFn()
|
||||
|
||||
for _, s := range active {
|
||||
for _, w := range s.waiters {
|
||||
b.grant(w, handlerResp{err: shutdownErr})
|
||||
}
|
||||
}
|
||||
for _, w := range queued {
|
||||
b.grant(w, handlerResp{err: shutdownErr})
|
||||
}
|
||||
b.schedule.OnShutdown(shutdownErr)
|
||||
|
||||
stopTimeout := req.timeout
|
||||
if stopTimeout <= 0 {
|
||||
@@ -535,6 +314,11 @@ func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSw
|
||||
<-done
|
||||
}
|
||||
|
||||
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
|
||||
// the process run-loop goroutines exit; they are already StateStopped, so
|
||||
// this is a clean no-op kill rather than a forced teardown.
|
||||
b.procCancel()
|
||||
|
||||
req.respond <- nil
|
||||
}
|
||||
|
||||
@@ -607,75 +391,6 @@ func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
|
||||
<-req.respond
|
||||
}
|
||||
|
||||
// handleUnload runs on the run loop in response to an Unload call. It
|
||||
// reconciles router-owned state with the impending Stop, then performs
|
||||
// the Stop synchronously so callers of Unload remain blocked until each
|
||||
// targeted process has actually exited.
|
||||
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(req.targets))
|
||||
for _, id := range req.targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
// Release waiters of any in-flight swap whose target is being
|
||||
// unloaded. The swap goroutine itself is left to finish on its own;
|
||||
// when its swapDone arrives, handleSwapDone will find no entry in
|
||||
// active and silently drop it.
|
||||
for id := range targetSet {
|
||||
s, ok := active[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, w := range s.waiters {
|
||||
b.grant(w, handlerResp{err: unloadErr})
|
||||
}
|
||||
delete(active, id)
|
||||
}
|
||||
|
||||
// Drop queued requests addressed to unloaded models. Requests for
|
||||
// other models stay queued and may benefit from drainQueue at the end.
|
||||
if len(*queued) > 0 {
|
||||
kept := (*queued)[:0]
|
||||
for _, w := range *queued {
|
||||
if targetSet[w.model] {
|
||||
b.grant(w, handlerResp{err: unloadErr})
|
||||
continue
|
||||
}
|
||||
kept = append(kept, w)
|
||||
}
|
||||
*queued = kept
|
||||
}
|
||||
|
||||
// Stop the targeted processes. Done synchronously so Unload's caller
|
||||
// can rely on "after Unload returns, the process is stopped". inFlight
|
||||
// is intentionally NOT cleared here: each dying handler will fire its
|
||||
// trackedServe defer and reach handleServeDone in the normal way once
|
||||
// the run loop is free again.
|
||||
var wg sync.WaitGroup
|
||||
for id := range targetSet {
|
||||
p, ok := b.processes[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(req.timeout); err != nil {
|
||||
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(id, p)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Removing entries from active above may have unblocked queued
|
||||
// requests that previously collided with the now-cancelled swaps.
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
|
||||
close(req.respond)
|
||||
}
|
||||
|
||||
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||
if !b.shuttingDown.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("%s shutdown already in progress", b.name)
|
||||
@@ -691,24 +406,24 @@ func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||
|
||||
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if b.shuttingDown.Load() {
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
data, err := FetchContext(req, b.config)
|
||||
data, err := shared.FetchContext(req, b.config)
|
||||
if err != nil {
|
||||
SendError(w, req, err)
|
||||
shared.SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
hr := handlerReq{
|
||||
model: data.ModelID,
|
||||
ctx: req.Context(),
|
||||
// Unbuffered: a successful send on respond proves the waiter is
|
||||
hr := scheduler.HandlerReq{
|
||||
Model: data.ModelID,
|
||||
Ctx: req.Context(),
|
||||
// Unbuffered: a successful send on Respond proves the waiter is
|
||||
// alive and consuming. grant() relies on this to avoid handing a
|
||||
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
||||
respond: make(chan handlerResp),
|
||||
positionCh: make(chan int, 1),
|
||||
Respond: make(chan scheduler.HandlerResp),
|
||||
PositionCh: make(chan int, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -716,7 +431,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
case <-req.Context().Done():
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,7 +451,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case pos := <-hr.positionCh:
|
||||
case pos := <-hr.PositionCh:
|
||||
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
||||
case <-swapCtx.Done():
|
||||
return
|
||||
@@ -745,31 +460,43 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
}()
|
||||
}
|
||||
|
||||
var resp handlerResp
|
||||
select {
|
||||
case resp = <-hr.respond:
|
||||
// finishLoading stops the loading stream and fences its goroutine off from
|
||||
// the ResponseWriter before the real handler (or ServeHTTP's return)
|
||||
// reclaims it. release() must run even when waitForCompletion times out:
|
||||
// otherwise a still-streaming goroutine flushes a finalized response and
|
||||
// panics on the recycled *bufio.Writer.
|
||||
finishLoading := func() {
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
lw.release()
|
||||
}
|
||||
}
|
||||
|
||||
var resp scheduler.HandlerResp
|
||||
select {
|
||||
case resp = <-hr.Respond:
|
||||
finishLoading()
|
||||
case <-req.Context().Done():
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
finishLoading()
|
||||
// Notify the scheduler so it can prune this request from its queue
|
||||
// and swap waiters. Without this, a queued request whose client left
|
||||
// would sit in the scheduler until drainQueue eventually starts a
|
||||
// wasted model load for it.
|
||||
select {
|
||||
case b.cancelCh <- hr:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
}
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
finishLoading()
|
||||
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
SendError(w, req, resp.err)
|
||||
if resp.Err != nil {
|
||||
shared.SendError(w, req, resp.Err)
|
||||
return
|
||||
}
|
||||
resp.handleFunc(w, req)
|
||||
resp.HandleFunc(w, req)
|
||||
}
|
||||
|
||||
+15
-614
@@ -5,35 +5,34 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
// stubPlanner is a swapPlanner that returns a fixed eviction list per target
|
||||
// and never logs. It lets the base-router tests cover shared run-loop
|
||||
// behaviour without dragging in either real router's eviction rules.
|
||||
type stubPlanner struct {
|
||||
evict map[string][]string
|
||||
}
|
||||
// These tests cover baseRouter's own machinery — the run loop, process
|
||||
// lifecycle (doSwap), grant/ServeHTTP plumbing, Unload, and Shutdown. The
|
||||
// scheduling decision logic (queueing, collation, eviction collisions) lives in
|
||||
// the scheduler package and is tested directly there; see fifo_test.go.
|
||||
|
||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||
if s.evict == nil {
|
||||
return nil
|
||||
}
|
||||
return s.evict[target]
|
||||
}
|
||||
// stubPlanner evicts nothing. baseRouter tests drive the run loop through the
|
||||
// default FIFO scheduler without exercising any particular eviction policy.
|
||||
type stubPlanner struct{}
|
||||
|
||||
func (s *stubPlanner) OnSwapStart(string) {}
|
||||
func (s *stubPlanner) EvictionFor(string, []string) []string { return nil }
|
||||
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||
|
||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
|
||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
|
||||
t.Helper()
|
||||
conf := config.Config{HealthCheckTimeout: 5}
|
||||
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||
b := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard),
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, eff)
|
||||
})
|
||||
b.testProcessed = make(chan struct{}, 64)
|
||||
go b.run()
|
||||
t.Cleanup(func() {
|
||||
@@ -157,114 +156,6 @@ func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
|
||||
// rejoins router state: a request whose swap to the unloaded model is
|
||||
// still in progress receives an error, instead of being abandoned
|
||||
// against a process that's about to vanish.
|
||||
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false: the swap parks on WaitReady so we can interrupt
|
||||
// it with Unload before it completes.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
close(done)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
|
||||
<-a.runStarted
|
||||
|
||||
b.Unload(time.Second, "a")
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("ServeHTTP did not return after Unload")
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if a.State() != process.StateStopped {
|
||||
t.Errorf("a state=%q want stopped", a.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
|
||||
// for an unloaded model receive an error rather than sitting forever in
|
||||
// the queue against state the router no longer maintains.
|
||||
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
// Loading B evicts A — so a request for B while A is loading queues.
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||
|
||||
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// r2 for B collides with A's in-flight swap and queues.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Unload B — r2 (queued, targeting B) must be released with an error.
|
||||
b.Unload(time.Second, "b")
|
||||
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("queued B request did not return after Unload(b)")
|
||||
}
|
||||
if w2.Code == http.StatusOK {
|
||||
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
|
||||
}
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
|
||||
}
|
||||
|
||||
// Release r1 so the test cleans up cleanly.
|
||||
a.markReady()
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r1 did not complete after a.markReady")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_FastPath(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("serveCalls=%d want 1", got)
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 0 {
|
||||
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.autoReady = true
|
||||
@@ -285,43 +176,6 @@ func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so the swap parks on WaitReady until we release it.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
const N = 5
|
||||
var wg sync.WaitGroup
|
||||
codes := make([]int, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
codes[i] = w.Code
|
||||
}(i)
|
||||
}
|
||||
|
||||
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
|
||||
<-a.runStarted // swap goroutine reached Run()
|
||||
a.markReady()
|
||||
wg.Wait()
|
||||
|
||||
for i, c := range codes {
|
||||
if c != http.StatusOK {
|
||||
t.Errorf("request %d: status=%d", i, c)
|
||||
}
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != N {
|
||||
t.Errorf("serveCalls=%d want %d", got, N)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so swap parks forever until we mark ready.
|
||||
@@ -364,459 +218,6 @@ func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pa := newFakeProcess("b")
|
||||
|
||||
// Loading b must stop a.
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
|
||||
|
||||
// First request starts a swap to A; A's autoReady=false so it parks.
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// Second request for B should queue while A's swap is in flight.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
if got := pa.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
|
||||
}
|
||||
|
||||
// Release A's swap. B's swap should then run.
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
|
||||
<-pa.runStarted
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("A request did not complete")
|
||||
}
|
||||
pa.markReady()
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("queued B request did not complete after A's swap")
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
|
||||
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
|
||||
// second request for each model rides the fast path — either joining the
|
||||
// active swap, or being pulled out of the queue when handleSwapDone promotes
|
||||
// the next model.
|
||||
func TestBaseRouter_QueueCollation(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// Each model evicts the other two so all swaps are mutually exclusive.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
"b": {"a", "c"},
|
||||
"c": {"a", "b"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
var (
|
||||
completedMu sync.Mutex
|
||||
completed []string
|
||||
)
|
||||
record := func(id string) {
|
||||
completedMu.Lock()
|
||||
defer completedMu.Unlock()
|
||||
completed = append(completed, id)
|
||||
}
|
||||
|
||||
ids := []string{"a", "b", "c", "a", "b", "c"}
|
||||
var wg sync.WaitGroup
|
||||
for _, id := range ids {
|
||||
id := id
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest(id))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
|
||||
return
|
||||
}
|
||||
record(id)
|
||||
}()
|
||||
// Wait for run() to absorb this request before launching the next,
|
||||
// so handlerCh receives them in launch order.
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
|
||||
// All 6 are now parked in run()'s waiters/queue. Release each swap in
|
||||
// sequence, waiting deterministically for each promotion to fire.
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
|
||||
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
|
||||
|
||||
<-pc.runStarted
|
||||
pc.markReady()
|
||||
wg.Wait()
|
||||
|
||||
if got := len(completed); got != 6 {
|
||||
t.Fatalf("completed=%v want 6", completed)
|
||||
}
|
||||
|
||||
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
|
||||
// but waiter goroutines may be scheduled in any order after their respond
|
||||
// channel fires, so completion order isn't deterministic. Per-model counts
|
||||
// (combined with the runCalls checks below) are sufficient to prove queue
|
||||
// collation collapsed each pair into a single swap.
|
||||
aDone, bDone, cDone := 0, 0, 0
|
||||
for _, id := range completed {
|
||||
switch id {
|
||||
case "a":
|
||||
aDone++
|
||||
case "b":
|
||||
bDone++
|
||||
case "c":
|
||||
cDone++
|
||||
}
|
||||
}
|
||||
if aDone != 2 || bDone != 2 || cDone != 2 {
|
||||
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
|
||||
}
|
||||
|
||||
// Single swap per model — the second request for each must have ridden
|
||||
// the fast path (joined active swap or joined a queued sibling), not
|
||||
// triggered an extra Run.
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("a.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := pb.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := pc.runCalls.Load(); got != 1 {
|
||||
t.Errorf("c.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
|
||||
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
|
||||
// before either process is marked ready.
|
||||
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
// Empty evict sets for both: they can load in parallel.
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Both swaps must have reached Run() before either is marked ready —
|
||||
// proves they ran in parallel rather than serializing.
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request A did not complete")
|
||||
}
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request B did not complete")
|
||||
}
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
|
||||
// unblocks every queued request that no longer collides — they all start in
|
||||
// parallel rather than one-per-completion.
|
||||
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// A's swap evicts both B and C, so B and C must queue. Once A finishes
|
||||
// B and C themselves have empty evict sets, so they can start together.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// B and C arrive while A is loading. evict_b and evict_c are empty,
|
||||
// but collidesWith returns true because they appear in A's evict set.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
w3 := httptest.NewRecorder()
|
||||
done3 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w3, newRequest("c"))
|
||||
close(done3)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started early: runCalls=%d", got)
|
||||
}
|
||||
if got := pc.runCalls.Load(); got != 0 {
|
||||
t.Errorf("c started early: runCalls=%d", got)
|
||||
}
|
||||
|
||||
// Release A. The swapDone handler should drain the queue and start
|
||||
// both B and C in parallel.
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
|
||||
<-pb.runStarted
|
||||
<-pc.runStarted
|
||||
|
||||
pb.markReady()
|
||||
pc.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2, done3} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
|
||||
// the shutdown error to every waiter on every active swap AND to every
|
||||
// queued request.
|
||||
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// a and b load in parallel (empty evicts). c collides with both.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"c": {"a", "b"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
const waitersPer = 2
|
||||
var wg sync.WaitGroup
|
||||
codes := make([]int, 0, 2*waitersPer+1)
|
||||
var codesMu sync.Mutex
|
||||
record := func(code int) {
|
||||
codesMu.Lock()
|
||||
codes = append(codes, code)
|
||||
codesMu.Unlock()
|
||||
}
|
||||
|
||||
launch := func(model string) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest(model))
|
||||
record(w.Code)
|
||||
}()
|
||||
}
|
||||
|
||||
// Active swaps for a and b, each with 2 waiters.
|
||||
for i := 0; i < waitersPer; i++ {
|
||||
launch("a")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
for i := 0; i < waitersPer; i++ {
|
||||
launch("b")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
// c collides with both → queues.
|
||||
launch("c")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
if err := b.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
codesMu.Lock()
|
||||
defer codesMu.Unlock()
|
||||
if len(codes) != 2*waitersPer+1 {
|
||||
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
|
||||
}
|
||||
for i, c := range codes {
|
||||
if c == http.StatusOK {
|
||||
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
|
||||
// is not stopped to satisfy another model's swap while it is still handling
|
||||
// a request.
|
||||
//
|
||||
// Sequence:
|
||||
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
|
||||
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
|
||||
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
|
||||
// intent collides with A.
|
||||
// 4. r1 released — A finishes r1, then r3 is served by A.
|
||||
// 5. B's swap then proceeds; r2 is served by B.
|
||||
//
|
||||
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
|
||||
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
|
||||
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady left false: we markReady manually after observing runStarted,
|
||||
// so autoReady's setState(Ready) cannot race with a later Stop and leave
|
||||
// A in Ready, masking the bug.
|
||||
a.serveBlock = make(chan struct{})
|
||||
pb := newFakeProcess("b")
|
||||
// Same reasoning for B: park its swap on WaitReady until we choose.
|
||||
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||
|
||||
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A
|
||||
<-a.serveStarted
|
||||
|
||||
// r2 — would evict A. A must not be stopped while r1 is in flight.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// r3 — another request for A, arrives behind r2 and queues because
|
||||
// B's swap intent (which evicts A) is recorded as active.
|
||||
w3 := httptest.NewRecorder()
|
||||
done3 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w3, newRequest("a"))
|
||||
close(done3)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
|
||||
// The router must hold off B's swap until A has drained.
|
||||
close(a.serveBlock)
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r1 did not complete after serveBlock release")
|
||||
}
|
||||
|
||||
// Wait for B.Run before marking it ready: markReady before Run would
|
||||
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
|
||||
// implementation B's swap only starts after A has drained; in the
|
||||
// current implementation it has already started — either way runStarted
|
||||
// fires.
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r2 did not complete after B marked ready")
|
||||
}
|
||||
select {
|
||||
case <-done3:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r3 did not complete")
|
||||
}
|
||||
|
||||
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
|
||||
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
|
||||
}
|
||||
if w1.Body.String() != "ok:a" {
|
||||
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
|
||||
}
|
||||
if w3.Body.String() != "ok:a" {
|
||||
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
|
||||
}
|
||||
if w2.Body.String() != "ok:b" {
|
||||
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
|
||||
}
|
||||
if a.stoppedWhileServing.Load() {
|
||||
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
# Router design
|
||||
|
||||
A developer tutorial for the `internal/router` package and its `scheduler`
|
||||
sub-package.
|
||||
|
||||
## Intro
|
||||
|
||||
A llama-swap router is the component that sits behind the proxy and answers one
|
||||
question for every incoming request: _can this model serve right now, and if
|
||||
not, what has to happen first?_ Answering it means juggling three concerns that
|
||||
used to live tangled together in one type:
|
||||
|
||||
1. **Process machinery** — owning the OS processes, starting and stopping them,
|
||||
running health checks, and shuttling HTTP requests onto the right upstream.
|
||||
2. **Scheduling strategy** — the queue, in-flight bookkeeping, and the decision
|
||||
tree that turns one request into "serve now", "join an existing swap",
|
||||
"queue", or "start a swap".
|
||||
3. **Eviction policy** — given a model we want to load, which currently-running
|
||||
models have to be stopped to make room?
|
||||
|
||||
The design pulls those three apart into separate, independently replaceable
|
||||
pieces:
|
||||
|
||||
| Concern | Type | Lives in |
|
||||
| ------------------- | ------------------------------ | ------------------------------- |
|
||||
| Process machinery | `baseRouter` | `internal/router/base.go` |
|
||||
| Scheduling strategy | `scheduler.Scheduler` (`FIFO`) | `internal/router/scheduler/` |
|
||||
| Eviction policy | `scheduler.Swapper` | `groupSwapper`, `matrixSwapper` |
|
||||
|
||||
`baseRouter` keeps the channels, run loop, process lifecycle, and shutdown
|
||||
teardown, and exposes the side-effects a scheduler needs through the
|
||||
`scheduler.Effects` interface. The scheduler owns the queue and decision tree
|
||||
but performs no side-effects directly — it calls back through `Effects`. The
|
||||
`Swapper` is a pure function from "target model + currently running" to "models
|
||||
to evict", and knows nothing about queues, channels, or processes.
|
||||
|
||||
Because the seams are interfaces, you can replace the scheduling strategy
|
||||
without touching process management, or write a new eviction policy without
|
||||
touching either. `FIFO` is the first and currently only `Scheduler`;
|
||||
`groupSwapper` and `matrixSwapper` are the two `Swapper`s.
|
||||
|
||||
## Key concepts
|
||||
|
||||
### One run loop, no locks
|
||||
|
||||
`baseRouter.run()` is a single goroutine selecting over a handful of channels:
|
||||
|
||||
```go
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh: b.handleShutdown(req); return
|
||||
case req := <-b.handlerCh: b.schedule.OnRequest(req)
|
||||
case req := <-b.unloadCh: b.schedule.OnUnload(req.targets, req.timeout); close(req.respond)
|
||||
case ev := <-b.swapDoneCh: b.schedule.OnSwapDone(ev)
|
||||
case ev := <-b.serveDoneCh: b.schedule.OnServeDone(ev)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Every `Scheduler` method runs on this one goroutine. That is the single most
|
||||
important fact about the design: **the scheduler never needs a mutex for its own
|
||||
state**. All scheduler state is touched only from these callbacks, which are
|
||||
serialized by the run loop. If you write a new scheduler, you get the same
|
||||
guarantee for free — and you must not break it by spinning up goroutines that
|
||||
mutate scheduler state.
|
||||
|
||||
### Events flow in, side-effects flow out
|
||||
|
||||
The run loop turns external happenings into method calls on the scheduler:
|
||||
|
||||
- A new HTTP request becomes `OnRequest(HandlerReq)`.
|
||||
- A swap goroutine finishing becomes `OnSwapDone(SwapDone)`.
|
||||
- A tracked request handler returning becomes `OnServeDone(ServeDoneEvent)`.
|
||||
- An admin unload becomes `OnUnload(targets, timeout)`.
|
||||
- Shutdown becomes `OnShutdown(err)`.
|
||||
|
||||
The scheduler reacts by calling **back out** through `Effects`: inspect a
|
||||
process state, start a swap, grant a response to a caller, or stop processes. It
|
||||
never calls `process.Process` directly and never writes to a channel directly.
|
||||
This keeps the scheduler pure enough to unit-test against a fake `Effects` with
|
||||
no goroutines or real processes involved (see `scheduler/fifo_test.go`).
|
||||
|
||||
```
|
||||
HTTP request admin Unload / Shutdown
|
||||
│ │
|
||||
▼ ▼
|
||||
ServeHTTP ──HandlerReq──▶ baseRouter.run() ◀──unloadCh/shutdownCh
|
||||
│ (single goroutine)
|
||||
▼
|
||||
Scheduler.On*(...)
|
||||
│ calls back through
|
||||
▼
|
||||
Effects: ModelState / StartSwap /
|
||||
GrantServe / GrantError / StopProcesses
|
||||
│
|
||||
▼
|
||||
baseRouter side-effects: doSwap goroutine,
|
||||
grant() to caller, process.Stop()
|
||||
│
|
||||
swap completes ──SwapDone──▶ back into run loop
|
||||
```
|
||||
|
||||
### The swap goroutine
|
||||
|
||||
Scheduling decisions must be quick and non-blocking, but loading a model is
|
||||
slow. The two are reconciled by doing the slow part on a separate goroutine.
|
||||
|
||||
When the scheduler decides to start a swap, inside `OnRequest` it:
|
||||
|
||||
1. records "a swap for X is in flight" in its own state, then
|
||||
2. calls `Effects.StartSwap(modelID, evict)`.
|
||||
|
||||
`StartSwap` does **not** load the model itself — it just launches a detached
|
||||
goroutine (`doSwap`) and returns straight away. `doSwap` is what does the slow
|
||||
work: stop the evicted processes, start the target, wait for it to become ready.
|
||||
Because `StartSwap` returned immediately, `OnRequest` returns too, and the run
|
||||
loop is free to pick up the next event — another request, a serve-done, an
|
||||
unload — while `doSwap` runs in the background.
|
||||
|
||||
The swap's eventual result comes back as just another event: when `doSwap`
|
||||
finishes it posts a `SwapDone` onto `swapDoneCh`, which the run loop delivers as
|
||||
`OnSwapDone`. So a slow load never blocks the run loop; it brackets it with two
|
||||
quick events (`OnRequest` to start, `OnSwapDone` to finish) and everything in
|
||||
between is handled normally.
|
||||
|
||||
### In-flight tracking and `trackedServe`
|
||||
|
||||
When the scheduler grants a request, the handler it hands back is wrapped by
|
||||
`baseRouter.trackedServe`. The wrapper runs the real `ServeHTTP` and, on return,
|
||||
posts a `ServeDoneEvent` so the run loop can decrement the per-model in-flight
|
||||
count. This is why the scheduler can know whether a process is "busy": it counts
|
||||
grants out and serve-dones in. A swap that would evict a busy process is
|
||||
deferred until that process's in-flight count hits zero (`OnServeDone` then
|
||||
re-drains the queue).
|
||||
|
||||
The subtle contract here is `GrantServe`'s boolean return. The caller's
|
||||
`Respond` channel is unbuffered, so a successful send proves the HTTP goroutine
|
||||
is alive and took the handler. If the caller already disconnected, the send
|
||||
fails, `trackedServe` never runs, and **no** `ServeDoneEvent` will ever arrive —
|
||||
so the scheduler must only increment `inFlight` when `GrantServe` returns true.
|
||||
Incrementing on a false return would strand the counter above zero and the model
|
||||
could never be evicted again.
|
||||
|
||||
## The interfaces
|
||||
|
||||
All three live in `scheduler/scheduler.go`.
|
||||
|
||||
### `Scheduler`
|
||||
|
||||
```go
|
||||
type Scheduler interface {
|
||||
OnRequest(req HandlerReq)
|
||||
OnSwapDone(ev SwapDone)
|
||||
OnServeDone(ev ServeDoneEvent)
|
||||
OnUnload(targets []string, timeout time.Duration)
|
||||
OnShutdown(err error)
|
||||
}
|
||||
```
|
||||
|
||||
Owns the queue, in-flight tracking, and the decision tree. All methods run on
|
||||
the run-loop goroutine, so no internal locking is needed.
|
||||
|
||||
### `Swapper`
|
||||
|
||||
```go
|
||||
type Swapper interface {
|
||||
EvictionFor(target string, running []string) []string
|
||||
OnSwapStart(target string, running []string)
|
||||
}
|
||||
```
|
||||
|
||||
The eviction policy. `EvictionFor` is a **pure decision** — given the target and
|
||||
the complete `running` set, return the running model IDs that must stop. It must
|
||||
not log or mutate anything, and it does **not** inspect process state itself:
|
||||
the scheduler hands it `running` already assembled (every non-stopped process,
|
||||
unioned with the targets of in-flight swaps already committed but not yet
|
||||
visible in process state). That keeps the swapper a pure function of its inputs,
|
||||
with no reference to processes.
|
||||
|
||||
The reason it must not log is that it is a _speculative_ query — "what would we
|
||||
evict if we started this swap right now?" — called far more often than swaps
|
||||
actually happen. The scheduler calls it once per incoming request, and then
|
||||
**again for every still-queued request on every queue drain** (each `OnSwapDone`,
|
||||
`OnServeDone`, and `OnUnload`). Most of those calls end in "still queued",
|
||||
"collides", or "nothing to evict", not a real swap. Logging there would emit
|
||||
duplicate lines for a request that simply sits in the queue, and lines for
|
||||
decisions that never happen — the log would stop meaning "a swap occurred".
|
||||
|
||||
`OnSwapStart` is the one place a Swapper may log, because it is called exactly
|
||||
once, at the moment a swap is committed. One log line there equals one real swap,
|
||||
with the evict set that is genuinely being applied — which is why `matrixSwapper`
|
||||
re-solves and logs the full decision (set, DSL, cost) in `OnSwapStart` rather
|
||||
than in `EvictionFor`.
|
||||
|
||||
### `Effects`
|
||||
|
||||
```go
|
||||
type Effects interface {
|
||||
ModelState(modelID string) (process.ProcessState, bool)
|
||||
RunningModels() map[string]process.ProcessState
|
||||
StartSwap(modelID string, evict []string)
|
||||
GrantError(req HandlerReq, err error)
|
||||
GrantServe(req HandlerReq, modelID string) bool
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
```
|
||||
|
||||
Implemented by `baseRouter`. This is the scheduler's entire window onto the
|
||||
outside world; everything else about the router is hidden from it. See the
|
||||
deep-dive below.
|
||||
|
||||
### `Factory` — wiring it together
|
||||
|
||||
```go
|
||||
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler
|
||||
```
|
||||
|
||||
`baseRouter` doesn't know which scheduler or swapper it has — it is handed a
|
||||
`Factory` at construction and calls it once, passing itself as the `Effects`.
|
||||
The concrete router captures its `Swapper` in the closure. From `group.go`:
|
||||
|
||||
```go
|
||||
swapper := &groupSwapper{ /* ... */ }
|
||||
base := newBaseRouter("group", conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, swapper, eff)
|
||||
})
|
||||
```
|
||||
|
||||
This closure is the single point where the three pieces meet: it binds a
|
||||
specific `Swapper` (`swapper`) and a specific `Scheduler` (`FIFO`) to the
|
||||
`baseRouter`'s `Effects` (`eff`).
|
||||
|
||||
**The swapper is a separate type from the concrete router.** There are currently two router implementations router.Group and router.Matrix. Each of these has a custom swapper that implements scheduler.Swapper for custom eviction logic. This decoupling of responsibilities makes it easy to implement custom swapping strategies.
|
||||
|
||||
### The events
|
||||
|
||||
A single goroutine in `baseRouter.run()` owns and serializes all state changes in the router. By processing events one at a time it ensures correctness and eliminates complex mutex lock logic.
|
||||
|
||||
These are the events the router currently uses:
|
||||
|
||||
```go
|
||||
type HandlerReq struct { // one in-flight ServeHTTP awaiting a decision
|
||||
Model string
|
||||
Ctx context.Context
|
||||
Respond chan HandlerResp // UNBUFFERED — see GrantServe contract
|
||||
PositionCh chan int // queue-position updates for the loading UI
|
||||
}
|
||||
|
||||
type HandlerResp struct { // the decision handed back to the caller
|
||||
HandleFunc http.HandlerFunc // serve with this, or...
|
||||
Err error // ...fail with this
|
||||
}
|
||||
|
||||
type SwapDone struct{ ModelID string; Err error } // swap goroutine finished
|
||||
type ServeDoneEvent struct{ ModelID string } // tracked handler returned
|
||||
```
|
||||
|
||||
## Deep-dive: the `Effects` interface and why it exists
|
||||
|
||||
`Effects` is the inversion-of-control boundary that makes the split possible.
|
||||
The scheduler decides and `baseRouter` _acts_. Pulling the side-effects behind this
|
||||
interface buys three things:
|
||||
|
||||
1. **Purity and testability.** The scheduler performs no I/O, starts no
|
||||
goroutines of its own, and touches no real processes. Its tests drive the
|
||||
`On*` methods directly and assert on a `fakeEffects` that just records the
|
||||
calls — synchronous, deterministic, no sleeps. (`scheduler/fifo_test.go`.)
|
||||
2. **A single, auditable side-effect surface.** Every externally-visible thing a
|
||||
scheduler can do is one of six methods. You can reason about the whole
|
||||
contract by reading one interface.
|
||||
3. **Decoupling lifetime.** The scheduler never holds a `process.Process`,
|
||||
never sees a channel, and never learns how shutdown teardown works. It only
|
||||
knows model IDs and states.
|
||||
|
||||
Method by method, as implemented in `base.go`:
|
||||
|
||||
- **`ModelState(modelID) (state, ok)`** — read-only snapshot of a process's
|
||||
state, and whether this router handles the model at all. The scheduler uses it
|
||||
for the "unknown model" check and the "already ready" fast path. Safe to call
|
||||
any time because the process map is fixed at construction and `State()` is a
|
||||
snapshot.
|
||||
|
||||
- **`RunningModels()`** — the state of every process that isn't stopped or shut
|
||||
down. The scheduler unions its keys with its own in-flight swap targets to
|
||||
build the `running` set it hands the `Swapper`, so the swapper never has to
|
||||
touch process state itself.
|
||||
|
||||
- **`StartSwap(modelID, evict)`** — fire-and-forget. `baseRouter` launches the
|
||||
`doSwap` goroutine and returns immediately; the result comes back later as a
|
||||
`SwapDone`. The scheduler records the swap as active _before_ calling this so
|
||||
that requests arriving in the meantime can join it.
|
||||
|
||||
- **`GrantError(req, err)`** — hand a caller an error response. Used for unknown
|
||||
models, failed swaps, unloads, and shutdown.
|
||||
|
||||
- **`GrantServe(req, modelID) bool`** — hand a caller the tracked handler for a
|
||||
ready model, returning whether the caller was still there to receive it. The
|
||||
scheduler increments the in-flight count **only on a true return** (see the
|
||||
in-flight contract above). This is the one `Effects` method whose return value
|
||||
carries state-machine significance.
|
||||
|
||||
- **`StopProcesses(timeout, ids)`** — stop processes in parallel and **block**
|
||||
until all have stopped. Used by `OnUnload` so an admin `Unload` call can
|
||||
guarantee the process is dead by the time it returns. (Note `StartSwap` is
|
||||
async but `StopProcesses` is sync — the difference is deliberate and tied to
|
||||
the caller's expectations.)
|
||||
|
||||
A useful way to hold it in your head: `Effects` is the scheduler's syscall
|
||||
table. The scheduler is a pure state machine; `Effects` is how it touches the
|
||||
world, and `baseRouter` is the kernel that implements those syscalls with real
|
||||
goroutines, channels, and processes.
|
||||
|
||||
## How to implement a new `Swapper`
|
||||
|
||||
A `Swapper` is a pure decision function plus a logging hook — the easiest of the three pieces to replace.
|
||||
|
||||
1. **Write the swapper type** and give it whatever config it needs to make a
|
||||
decision. It does **not** need the process map — the scheduler supplies the
|
||||
running set as an argument. `groupSwapper` holds only its group config;
|
||||
`matrixSwapper` holds only its solver and logger:
|
||||
|
||||
```go
|
||||
type mySwapper struct {
|
||||
config config.Config
|
||||
}
|
||||
```
|
||||
|
||||
2. **Implement `EvictionFor(target, running)`** as a _pure_ decision:
|
||||
- `running` is the complete live set, already assembled for you: every
|
||||
non-stopped process unioned with the targets of in-flight swaps the
|
||||
scheduler has committed to. You don't filter process state or fold in
|
||||
in-flight targets yourself, that's the scheduler's job. Just decide against the slice you're handed.
|
||||
- Return the list of model IDs in `running` that must stop for `target` to
|
||||
run. Return `nil`/empty when nothing needs evicting.
|
||||
- Do **not** mutate state here.
|
||||
- Do **not** log here. It can be called multiple times per request. Since it is pure function have tests verify the expected behaviour.
|
||||
|
||||
3. **Implement `OnSwapStart(target, running)`** — called once when a swap
|
||||
actually begins, with the same `running` set `EvictionFor` saw. This is the
|
||||
right place to log: one call equals one real swap. `matrixSwapper` re-solves
|
||||
and logs the chosen set and cost here; `groupSwapper` logs nothing.
|
||||
|
||||
4. **Wire it in** by instantiating the swapper in your router's constructor and
|
||||
capturing it in the `Factory` closure passed to `newBaseRouter` — exactly as
|
||||
`NewGroup` and `NewMatrix` do. The router struct itself only ever embeds
|
||||
`*baseRouter`; the swapper reaches the scheduler solely through that closure.
|
||||
|
||||
Reference implementations: `groupSwapper` (static group config) in `group.go`
|
||||
and `matrixSwapper` (cost-based set solver) in `matrix.go`.
|
||||
|
||||
## How to implement a new `Scheduler`
|
||||
|
||||
Replacing the scheduler means taking over the queue and the entire decision tree. Read `scheduler/fifo.go` end to end first — it is the reference implementation and the rules below are easiest to understand in context.
|
||||
|
||||
The rules you must honour:
|
||||
|
||||
- **Single goroutine.** Every method runs on the `baseRouter.run()` goroutine. Keep your state in plain maps/slices and never read or write it from another goroutine. If you need slow work done, hand it to `Effects.StartSwap` and react to the resulting `SwapDone` — do not block a method waiting for it.
|
||||
|
||||
- **Never block the run loop.** `OnRequest`, `OnSwapDone`, and `OnServeDone` must make a decision and return. The one method allowed to block is `OnUnload`, and only because it must wait on the synchronous `StopProcesses` so the admin caller's guarantee holds.
|
||||
|
||||
- **Respect the `GrantServe` boolean.** Only count a request as in-flight when `GrantServe` returns true (see the in-flight contract above). A false return means the caller is gone; no `ServeDoneEvent` will ever arrive, so incrementing on false permanently strands the counter.
|
||||
|
||||
- **Account for in-flight swaps in your running set.** When you call `Swapper.EvictionFor`, the running set you pass must include not just live processes (`Effects.RunningModels`) but also the targets of swaps you've already started that aren't yet visible in process state — otherwise the swapper contradicts decisions already in motion.
|
||||
|
||||
What each method must do:
|
||||
|
||||
- **`OnRequest(req)`** — every request must resolve to exactly one of: granted, errored, joined (piggybacks an in-flight swap), queued, or swap-started. No request may be silently dropped.
|
||||
|
||||
- **`OnSwapDone(ev)`** — deliver the result to every waiter that joined this swap (grant on success, error on `ev.Err`), drop the swap from active tracking, then re-examine anything queued — a finished swap may have unblocked it.
|
||||
|
||||
- **`OnServeDone(ev)`** — decrement the model's in-flight count; when it hits zero, re-examine the queue. Do **not** clear in-flight counts by hand; the handlers post their own `ServeDoneEvent`s on return.
|
||||
|
||||
- **`OnUnload(targets, timeout)`** — error out any waiters or queued requests for the unloaded models, call `Effects.StopProcesses` (synchronously — the admin caller relies on the process being dead afterwards), then re-examine the queue.
|
||||
|
||||
- **`OnShutdown(err)`** — error out every waiter you still hold (active swap waiters and queued requests). Don't touch processes; teardown is `baseRouter`'s job.
|
||||
|
||||
Expose a constructor matching the `Factory` shape:
|
||||
|
||||
```go
|
||||
func NewMyScheduler(name string, logger *logmon.Monitor, swapper Swapper, eff Effects) *MyScheduler {
|
||||
// ...
|
||||
}
|
||||
|
||||
// in the concrete router:
|
||||
base := newBaseRouter(name, conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewMyScheduler(name, logger, swapper, eff)
|
||||
})
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
- **Schedulers** are tested as pure state machines in the `scheduler` package:
|
||||
drive the `On*` methods directly against a `fakeEffects` and assert on the
|
||||
recorded grants/starts/stops. No goroutines, no sleeps. See
|
||||
`scheduler/fifo_test.go` as the reference; follow the `TestSchedulerName_<scenario>`
|
||||
naming convention.
|
||||
- **`baseRouter` mechanism** (run loop, `grant`/`ServeHTTP`, `Unload`,
|
||||
`Shutdown`) is tested in `base_test.go`. The run loop exposes a
|
||||
`testProcessed` channel so tests can wait for an event to be fully processed
|
||||
instead of sleeping.
|
||||
- Run new tests with `go test -v -run TestMyScheduler_... ./internal/router/scheduler/`,
|
||||
then `make test-dev` for a quick `go test` + `staticcheck` pass over `proxy/`.
|
||||
+17
-20
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
@@ -14,7 +15,7 @@ type Group struct {
|
||||
|
||||
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Groups {
|
||||
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
if existing, dup := modelToGroup[mid]; dup {
|
||||
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
||||
@@ -23,25 +24,29 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
|
||||
}
|
||||
}
|
||||
|
||||
planner := &groupPlanner{
|
||||
swapper := &groupSwapper{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
}
|
||||
|
||||
processes := make(map[string]process.Process, len(modelToGroup))
|
||||
base := newBaseRouter("group", conf, processes, planner, proxylog)
|
||||
planner.processes = processes
|
||||
base := newBaseRouter("group", conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff)
|
||||
})
|
||||
|
||||
for mid := range modelToGroup {
|
||||
modelCfg, _, ok := conf.FindConfig(mid)
|
||||
if !ok {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("no model config for %q", mid)
|
||||
}
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
|
||||
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
@@ -52,21 +57,20 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// groupPlanner decides evictions from static group configuration.
|
||||
// groupSwapper decides evictions from static group configuration.
|
||||
//
|
||||
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
||||
// members are stopped only when the target's group is exclusive; loading a
|
||||
// model from a non-exclusive group leaves running exclusive groups alone,
|
||||
// matching the gotcha in the original ProcessGroup behaviour.
|
||||
type groupPlanner struct {
|
||||
type groupSwapper struct {
|
||||
config config.Config
|
||||
modelToGroup map[string]string
|
||||
processes map[string]process.Process
|
||||
}
|
||||
|
||||
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||
func (p *groupSwapper) EvictionFor(target string, running []string) []string {
|
||||
tg := p.modelToGroup[target]
|
||||
tgCfg := p.config.Groups[tg]
|
||||
tgCfg := p.config.Routing.Router.Settings.Groups[tg]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
var result []string
|
||||
@@ -87,24 +91,17 @@ func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string
|
||||
// for backwards compatibility. The newer swap matrix approach does not
|
||||
// have this issue.
|
||||
case og != tg && tgCfg.Exclusive:
|
||||
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
|
||||
if ogCfg := p.config.Routing.Router.Settings.Groups[og]; !ogCfg.Persistent {
|
||||
seen[mID] = struct{}{}
|
||||
result = append(result, mID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for mID, proc := range p.processes {
|
||||
st := proc.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
consider(mID)
|
||||
}
|
||||
for _, mID := range alsoRunning {
|
||||
for _, mID := range running {
|
||||
consider(mID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *groupPlanner) OnSwapStart(target string) {}
|
||||
func (p *groupSwapper) OnSwapStart(target string, running []string) {}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
// newTestGroup builds a Group directly from the supplied processes and config,
|
||||
@@ -17,17 +18,19 @@ import (
|
||||
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
||||
t.Helper()
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Groups {
|
||||
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
modelToGroup[mid] = gid
|
||||
}
|
||||
}
|
||||
planner := &groupPlanner{
|
||||
swapper := &groupSwapper{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
processes: processes,
|
||||
}
|
||||
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||
base := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard),
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff)
|
||||
})
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
g := &Group{baseRouter: base}
|
||||
go base.run()
|
||||
@@ -41,10 +44,10 @@ func newTestGroup(t *testing.T, conf config.Config, processes map[string]process
|
||||
|
||||
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
||||
conf := config.Config{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Members: []string{"a"}},
|
||||
},
|
||||
}),
|
||||
Models: map[string]config.ModelConfig{
|
||||
"a": {},
|
||||
},
|
||||
@@ -65,9 +68,9 @@ func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -97,9 +100,9 @@ func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -127,10 +130,10 @@ func TestGroup_CrossGroupExclusive(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -154,10 +157,10 @@ func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
@@ -202,16 +205,17 @@ func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||
|
||||
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
||||
// (Swap=true) serialise even when both arrive while neither has reached
|
||||
// StateStarting yet — the alsoRunning hint to the planner closes that race.
|
||||
// StateStarting yet — the in-flight swap target the scheduler folds into the
|
||||
// running set closes that race.
|
||||
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
@@ -224,8 +228,9 @@ func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
// Request B arrives before A transitions to StateStarting in the process
|
||||
// state machine. Without the alsoRunning hint, the planner would not see
|
||||
// A as running, and B would start in parallel, violating Swap=true.
|
||||
// state machine. Without folding the in-flight swap target into the running
|
||||
// set, the swapper would not see A as running, and B would start in
|
||||
// parallel, violating Swap=true.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
@@ -269,10 +274,10 @@ func TestGroup_PersistentNotEvicted(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
||||
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
@@ -306,10 +311,10 @@ func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
Routing: groupRouting(map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
|
||||
@@ -12,10 +12,23 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// groupRouting builds a normalized RoutingConfig for the group router, mirroring
|
||||
// what config.LoadConfigFromReader produces. Tests use it to populate
|
||||
// config.Config.Routing without going through LoadConfig.
|
||||
func groupRouting(groups map[string]config.GroupConfig) config.RoutingConfig {
|
||||
return config.RoutingConfig{
|
||||
Router: config.RouterConfig{
|
||||
Use: "group",
|
||||
Settings: config.RouterSettings{Groups: groups},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// fakeProcess is an in-memory implementation of process.Process used to drive
|
||||
// the routers through their state machine without spawning real upstreams.
|
||||
type fakeProcess struct {
|
||||
|
||||
@@ -38,6 +38,13 @@ type loadingWriter struct {
|
||||
pendingMu sync.Mutex
|
||||
pendingUpdate string
|
||||
|
||||
// writeMu serializes writes to the underlying writer and guards released.
|
||||
// Once released is set, the streaming goroutine must not touch the writer
|
||||
// again — ServeHTTP has reclaimed it (to run the real handler or to return)
|
||||
// and writing/flushing a finalized response panics.
|
||||
writeMu sync.Mutex
|
||||
released bool
|
||||
|
||||
// closed by start when the goroutine finishes (after cleanup messages)
|
||||
done chan struct{}
|
||||
|
||||
@@ -217,12 +224,33 @@ func (s *loadingWriter) sendData(data string) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
// Once ServeHTTP has reclaimed the writer (release), writing/flushing it
|
||||
// races the real handler or panics on a finalized response. Stop here.
|
||||
if s.released {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData); err != nil {
|
||||
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
|
||||
return
|
||||
}
|
||||
s.Flush()
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// release fences the loadingWriter off from the underlying ResponseWriter.
|
||||
// After it returns, the streaming goroutine will not write to or flush the
|
||||
// writer again: any in-flight write completes under writeMu first, and later
|
||||
// writes short-circuit on released. The caller can then safely hand the writer
|
||||
// to the real handler or let ServeHTTP return without racing a finalized
|
||||
// response (a use-after-return Flush panics on the recycled *bufio.Writer).
|
||||
func (s *loadingWriter) release() {
|
||||
s.writeMu.Lock()
|
||||
s.released = true
|
||||
s.writeMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Header() http.Header {
|
||||
|
||||
@@ -226,69 +226,6 @@ func TestIsLoadingPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", "model=llama3&stream=true", true},
|
||||
{"streaming false", "model=llama3&stream=false", false},
|
||||
{"no stream param", "model=llama3", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||
{"no stream param", `{"model":"llama3"}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if !got.Streaming {
|
||||
t.Error("Streaming should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func countSSEMessages(s string) int {
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
count := 0
|
||||
|
||||
+19
-46
@@ -2,11 +2,11 @@ package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
type Matrix struct {
|
||||
@@ -14,26 +14,30 @@ type Matrix struct {
|
||||
}
|
||||
|
||||
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
||||
if conf.Matrix == nil {
|
||||
mtx := conf.Routing.Router.Settings.Matrix
|
||||
if mtx == nil {
|
||||
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
||||
}
|
||||
|
||||
planner := &matrixPlanner{
|
||||
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
|
||||
swapper := &matrixSwapper{
|
||||
solver: newMatrixSolver(mtx.ExpandedSets, mtx.ResolvedEvictCosts()),
|
||||
logger: proxylog,
|
||||
}
|
||||
|
||||
// Build a process for every model in the config. Any model can run alone
|
||||
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
||||
processes := make(map[string]process.Process, len(conf.Models))
|
||||
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
|
||||
planner.processes = processes
|
||||
base := newBaseRouter("matrix", conf, processes, proxylog,
|
||||
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff)
|
||||
})
|
||||
|
||||
for mid, modelCfg := range conf.Models {
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
|
||||
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
base.procCancel()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
@@ -44,20 +48,18 @@ func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matr
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// matrixPlanner decides evictions by asking the matrix solver against the
|
||||
// current running set.
|
||||
type matrixPlanner struct {
|
||||
solver *matrixSolver
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
// matrixSwapper decides evictions by asking the matrix solver against the
|
||||
// running set the scheduler hands it.
|
||||
type matrixSwapper struct {
|
||||
solver *matrixSolver
|
||||
logger *logmon.Monitor
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
|
||||
func (p *matrixSwapper) EvictionFor(target string, running []string) []string {
|
||||
return p.solver.Solve(target, running).Evict
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) OnSwapStart(target string) {
|
||||
running := p.runningModels()
|
||||
func (p *matrixSwapper) OnSwapStart(target string, running []string) {
|
||||
result := p.solver.Solve(target, running)
|
||||
switch {
|
||||
case len(result.Evict) > 0:
|
||||
@@ -69,32 +71,3 @@ func (p *matrixPlanner) OnSwapStart(target string) {
|
||||
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) runningModels() []string {
|
||||
return p.runningSet(nil)
|
||||
}
|
||||
|
||||
// runningSet returns the union of live processes (State != Stopped/Shutdown)
|
||||
// and any extra IDs the baseRouter has already committed to loading but which
|
||||
// the process state machine has not yet reflected.
|
||||
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
|
||||
seen := make(map[string]struct{}, len(p.processes))
|
||||
var running []string
|
||||
for id, proc := range p.processes {
|
||||
st := proc.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
running = append(running, id)
|
||||
}
|
||||
for _, id := range alsoRunning {
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
running = append(running, id)
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||
)
|
||||
|
||||
// newTestMatrix builds a Matrix router from supplied processes, bypassing
|
||||
@@ -17,12 +18,14 @@ import (
|
||||
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
planner := &matrixPlanner{
|
||||
solver: newMatrixSolver(expanded, evictCosts),
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
swapper := &matrixSwapper{
|
||||
solver: newMatrixSolver(expanded, evictCosts),
|
||||
logger: logger,
|
||||
}
|
||||
base := newBaseRouter("matrix", conf, processes, planner, logger)
|
||||
base := newBaseRouter("matrix", conf, processes, logger,
|
||||
func(name string, l *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
|
||||
return scheduler.NewFIFO(name, l, swapper, conf.Routing.Scheduler.Settings.Fifo, eff)
|
||||
})
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
r := &Matrix{baseRouter: base}
|
||||
go base.run()
|
||||
@@ -153,8 +156,8 @@ func TestMatrix_CoexistingSetParallel(t *testing.T) {
|
||||
|
||||
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
||||
// that cannot coexist with the in-flight first model queues until the first
|
||||
// completes, and then evicts it. This exercises the alsoRunning hint via the
|
||||
// matrix solver's union into runningSet.
|
||||
// completes, and then evicts it. This exercises the scheduler folding in-flight
|
||||
// swap targets into the running set it hands the swapper.
|
||||
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
@@ -173,8 +176,9 @@ func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
// B arrives before A transitions to StateStarting. The solver sees A via
|
||||
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
|
||||
// B arrives before A transitions to StateStarting. The running set the
|
||||
// scheduler builds includes A (an in-flight swap target), so the solver
|
||||
// returns evict=[a] and collidesWith forces B to queue.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
+11
-11
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type peerMember struct {
|
||||
@@ -62,13 +63,12 @@ func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = req.URL.Host
|
||||
reverseProxy := &httputil.ReverseProxy{
|
||||
Transport: peerTransport,
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
r.SetURL(peer.ProxyURL)
|
||||
r.Out.Host = r.Out.URL.Host
|
||||
},
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
@@ -147,22 +147,22 @@ func (r *Peer) Shutdown(timeout time.Duration) error {
|
||||
|
||||
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.shuttingDown.Load() {
|
||||
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||
return
|
||||
}
|
||||
r.inflight.Add(1)
|
||||
defer r.inflight.Done()
|
||||
|
||||
data, err := FetchContext(req, r.cfg)
|
||||
data, err := shared.FetchContext(req, r.cfg)
|
||||
if err != nil {
|
||||
SendError(w, req, err)
|
||||
shared.SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
pp, found := r.peers[data.ModelID]
|
||||
if !found {
|
||||
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||
SendError(w, req, ErrNoPeerModelFound)
|
||||
shared.SendError(w, req, ErrNoPeerModelFound)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var testLogger = logmon.NewWriter(os.Stdout)
|
||||
@@ -142,7 +143,7 @@ func TestPeer_ServeHTTP_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -178,7 +179,7 @@ func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -212,7 +213,7 @@ func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -246,7 +247,7 @@ func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -279,7 +280,7 @@ func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -311,7 +312,7 @@ func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -347,7 +348,7 @@ func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
@@ -385,7 +386,7 @@ func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -448,7 +449,7 @@ func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -551,7 +552,7 @@ func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
||||
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
+4
-151
@@ -1,39 +1,18 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
type contextkey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type ReqContextData struct {
|
||||
Model string
|
||||
ModelID string
|
||||
Streaming bool
|
||||
SendLoadingState bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||
|
||||
ContextKey = &contextkey{"context"}
|
||||
ErrNoRouterFound = shared.ErrNoRouterFound
|
||||
ErrNoPeerModelFound = shared.ErrNoPeerModelFound
|
||||
ErrNoLocalModelFound = shared.ErrNoLocalModelFound
|
||||
)
|
||||
|
||||
type Router interface {
|
||||
@@ -71,129 +50,3 @@ type LocalRouter interface {
|
||||
// model is not known to this router.
|
||||
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context then
|
||||
// from the model body. If it extracts the model from the body it will
|
||||
// store the model in the context for downstream handlers. An error
|
||||
// will be returned when model can not be fetch from either location.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if data, err := ExtractContext(r); err == nil {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
realName = data.Model
|
||||
}
|
||||
data.ModelID = realName
|
||||
if mc, ok := cfg.Models[realName]; ok {
|
||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ContextKey, data)
|
||||
}
|
||||
|
||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// ExtractContext pulls the model name from an HTTP request without consuming the
|
||||
// body. For GET requests it reads the "model" query parameter. For POST
|
||||
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
|
||||
// application/x-www-form-urlencoded bodies. The request body is always restored
|
||||
// before returning so downstream handlers — including reverse proxies that
|
||||
// forward raw bytes upstream — can still read it.
|
||||
func ExtractContext(r *http.Request) (ReqContextData, error) {
|
||||
if r.Method == http.MethodGet {
|
||||
if model := r.URL.Query().Get("model"); model != "" {
|
||||
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
|
||||
}
|
||||
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}()
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
model := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if model == "" {
|
||||
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
|
||||
}
|
||||
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
|
||||
}
|
||||
|
||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||
// buffered bytes. The deferred restore above will reset r.Body again
|
||||
// after parsing.
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if model := r.FormValue("model"); model != "" {
|
||||
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
|
||||
}
|
||||
|
||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
switch {
|
||||
case errors.Is(err, ErrNoModelInContext):
|
||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||
case errors.Is(err, ErrNoPeerModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||
case errors.Is(err, ErrNoLocalModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
case errors.Is(err, ErrNoRouterFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||
default:
|
||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
// Check Accept header for preferred response format
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/plain") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(acceptHeader, "text/html") {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,451 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// activeSwap tracks one in-flight swap and the callers waiting on it.
|
||||
type activeSwap struct {
|
||||
modelID string
|
||||
evict []string
|
||||
waiters []HandlerReq
|
||||
}
|
||||
|
||||
// FIFO is the default scheduler. Requests are handled in a first-in, first-out order.
|
||||
// To reduce swapping requests for a model that is already running will be handled
|
||||
// immediately by the running process.
|
||||
//
|
||||
// Requests into this schedule are handled like this:
|
||||
//
|
||||
// A B C A B C --> A A B B C C
|
||||
//
|
||||
// The strategy is simple and reduces the number of swaps required.
|
||||
type FIFO struct {
|
||||
name string
|
||||
logger *logmon.Monitor
|
||||
planner Swapper
|
||||
cfg config.FifoConfig
|
||||
effects Effects
|
||||
|
||||
active map[string]*activeSwap
|
||||
inFlight map[string]int
|
||||
queued []HandlerReq
|
||||
}
|
||||
|
||||
// NewFIFO builds a FIFO scheduler. It matches scheduler.Factory once a planner
|
||||
// is captured in a closure.
|
||||
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, eff Effects) *FIFO {
|
||||
return &FIFO{
|
||||
name: name,
|
||||
logger: logger,
|
||||
planner: planner,
|
||||
cfg: cfg,
|
||||
effects: eff,
|
||||
active: make(map[string]*activeSwap),
|
||||
inFlight: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest decides what to do with one incoming ServeHTTP request. It never
|
||||
// blocks indefinitely: any work that has to wait (starting a process, stopping
|
||||
// siblings, waiting for ready) is deferred to a swap goroutine and reported back
|
||||
// via OnSwapDone.
|
||||
//
|
||||
// The decision tree, in order:
|
||||
//
|
||||
// 1. Unknown model — respond with ErrModelNotFound and move on.
|
||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||
// one swap serves all callers that asked for the same model.
|
||||
// 3. Fast path — the target process is already ready, the planner sees
|
||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||
// ServeHTTP immediately.
|
||||
// 4. Would collide with an in-flight swap (we'd stop their target, or they're
|
||||
// stopping us) — park in the queue for OnSwapDone to drain.
|
||||
// 5. Would evict a process that is still handling requests — park in the
|
||||
// queue. OnServeDone will retry when the busy process drains.
|
||||
// 6. Otherwise — start a new swap. This may run in parallel with other active
|
||||
// swaps when their evict sets don't intersect.
|
||||
func (s *FIFO) OnRequest(req HandlerReq) {
|
||||
// (1) Unknown model.
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// (2) Join an in-flight swap for the same model.
|
||||
if sw, ok := s.active[req.Model]; ok {
|
||||
s.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", s.name, req.Model, len(sw.waiters)+1)
|
||||
sw.waiters = append(sw.waiters, req)
|
||||
return
|
||||
}
|
||||
|
||||
running := s.runningSet(req.Model)
|
||||
evict := s.planner.EvictionFor(req.Model, running)
|
||||
|
||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: fast-path serving model %s (already ready)", s.name, req.Model)
|
||||
s.grantHandler(req, req.Model)
|
||||
return
|
||||
}
|
||||
|
||||
// (4) Collision with an in-flight swap — queue.
|
||||
if collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", s.name, req.Model)
|
||||
s.enqueue(req)
|
||||
return
|
||||
}
|
||||
|
||||
// (5) Would evict a busy process — queue until it drains.
|
||||
if conflictsWithInFlight(evict, s.inFlight) {
|
||||
s.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", s.name, req.Model)
|
||||
s.enqueue(req)
|
||||
return
|
||||
}
|
||||
|
||||
// (6) Start a new (possibly parallel) swap.
|
||||
s.logger.Debugf("%s: starting swap for model %s, evicting %v", s.name, req.Model, evict)
|
||||
s.startSwap(req, evict, running)
|
||||
}
|
||||
|
||||
// OnCancel removes a request whose client has disconnected from the queue and
|
||||
// from every in-flight swap's waiters. If the request was the sole waiter of an
|
||||
// active swap, the swap goroutine is left to complete on its own — OnSwapDone
|
||||
// will find no waiters and simply clean up. This prevents drainQueue from ever
|
||||
// starting a model load for a caller that is no longer there.
|
||||
func (s *FIFO) OnCancel(req HandlerReq) {
|
||||
removed := false
|
||||
|
||||
// Prune from the queue.
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, q := range s.queued {
|
||||
if q.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
}
|
||||
|
||||
// Prune from any active swap's waiters.
|
||||
for _, sw := range s.active {
|
||||
filtered := sw.waiters[:0]
|
||||
for _, w := range sw.waiters {
|
||||
if w.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, w)
|
||||
}
|
||||
sw.waiters = filtered
|
||||
}
|
||||
|
||||
if removed {
|
||||
s.logger.Debugf("%s: cancelled request for model %s pruned from scheduler", s.name, req.Model)
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
}
|
||||
|
||||
// OnSwapDone fans the result out to every waiter that joined this swap, removes
|
||||
// the swap from the active map, then walks the queue once, promoting any items
|
||||
// that no longer collide with the remaining active set. FIFO order is preserved:
|
||||
// items still blocked stay in place.
|
||||
func (s *FIFO) OnSwapDone(ev SwapDone) {
|
||||
sw, ok := s.active[ev.ModelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(s.active, ev.ModelID)
|
||||
|
||||
for _, w := range sw.waiters {
|
||||
if ev.Err != nil {
|
||||
s.effects.GrantError(w, ev.Err)
|
||||
} else {
|
||||
s.grantHandler(w, ev.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
s.drainQueue()
|
||||
}
|
||||
|
||||
// OnServeDone decrements the per-model in-flight count and, when that drops to
|
||||
// zero, retries the queue: requests whose swap was deferred because they would
|
||||
// have evicted this (now-idle) process can now proceed.
|
||||
func (s *FIFO) OnServeDone(ev ServeDoneEvent) {
|
||||
s.inFlight[ev.ModelID]--
|
||||
if s.inFlight[ev.ModelID] <= 0 {
|
||||
delete(s.inFlight, ev.ModelID)
|
||||
s.drainQueue()
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnload reconciles router-owned state with the impending Stop, performs the
|
||||
// Stop (synchronously, via Effects) so callers of Unload remain blocked until
|
||||
// each targeted process has exited, then drains the queue.
|
||||
func (s *FIFO) OnUnload(targets []string, timeout time.Duration) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(targets))
|
||||
for _, id := range targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
// Release waiters of any in-flight swap whose target is being unloaded.
|
||||
// The swap goroutine itself is left to finish on its own; when its
|
||||
// SwapDone arrives, OnSwapDone will find no entry in active and drop it.
|
||||
for id := range targetSet {
|
||||
sw, ok := s.active[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, w := range sw.waiters {
|
||||
s.effects.GrantError(w, unloadErr)
|
||||
}
|
||||
delete(s.active, id)
|
||||
}
|
||||
|
||||
// Drop queued requests addressed to unloaded models. Requests for other
|
||||
// models stay queued and may benefit from drainQueue at the end.
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, w := range s.queued {
|
||||
if targetSet[w.Model] {
|
||||
s.effects.GrantError(w, unloadErr)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, w)
|
||||
}
|
||||
s.queued = kept
|
||||
}
|
||||
|
||||
// Stop the targeted processes. Done synchronously so Unload's caller can
|
||||
// rely on "after Unload returns, the process is stopped". inFlight is
|
||||
// intentionally NOT cleared here: each dying handler will fire its tracked
|
||||
// serve and reach OnServeDone in the normal way.
|
||||
s.effects.StopProcesses(timeout, targets)
|
||||
|
||||
// Removing entries from active above may have unblocked queued requests
|
||||
// that previously collided with the now-cancelled swaps.
|
||||
s.drainQueue()
|
||||
}
|
||||
|
||||
// OnShutdown grants err to every waiter still held by the scheduler.
|
||||
func (s *FIFO) OnShutdown(err error) {
|
||||
for _, sw := range s.active {
|
||||
for _, w := range sw.waiters {
|
||||
s.effects.GrantError(w, err)
|
||||
}
|
||||
}
|
||||
for _, w := range s.queued {
|
||||
s.effects.GrantError(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
// grantHandler hands the caller a tracked handler for modelID and, only if the
|
||||
// caller was still there to receive it, bumps the in-flight count. Incrementing
|
||||
// when the grant failed would strand the counter and block future evictions.
|
||||
func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
|
||||
if s.effects.GrantServe(req, modelID) {
|
||||
s.inFlight[modelID]++
|
||||
}
|
||||
}
|
||||
|
||||
// startSwap records the swap as active and launches it via Effects. running is
|
||||
// the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against
|
||||
// the same picture it decided on.
|
||||
func (s *FIFO) startSwap(initial HandlerReq, evict, running []string) {
|
||||
s.active[initial.Model] = &activeSwap{
|
||||
modelID: initial.Model,
|
||||
evict: evict,
|
||||
waiters: []HandlerReq{initial},
|
||||
}
|
||||
s.planner.OnSwapStart(initial.Model, running)
|
||||
s.effects.StartSwap(initial.Model, evict)
|
||||
}
|
||||
|
||||
// enqueue inserts req into the queue in priority order: it goes just before the
|
||||
// first queued item whose priority is strictly lower, so higher-priority models
|
||||
// are serviced first while equal-priority requests keep their arrival (FIFO)
|
||||
// order. Priorities come from the FifoConfig; unlisted models default to 0.
|
||||
func (s *FIFO) enqueue(req HandlerReq) {
|
||||
p := s.cfg.Priority[req.Model]
|
||||
i := len(s.queued)
|
||||
for j, q := range s.queued {
|
||||
if s.cfg.Priority[q.Model] < p {
|
||||
i = j
|
||||
break
|
||||
}
|
||||
}
|
||||
s.queued = append(s.queued, HandlerReq{})
|
||||
copy(s.queued[i+1:], s.queued[i:])
|
||||
s.queued[i] = req
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
// drainQueue walks the queued requests in order, re-running the OnRequest
|
||||
// decision tree against the (now smaller) active set. Items that can now start
|
||||
// or join become satisfied; items still blocked remain queued in original order
|
||||
// so they get another chance on the next swap completion.
|
||||
func (s *FIFO) drainQueue() {
|
||||
if len(s.queued) == 0 {
|
||||
return
|
||||
}
|
||||
pending := s.queued
|
||||
var remaining []HandlerReq
|
||||
for _, req := range pending {
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
continue
|
||||
}
|
||||
if sw, ok := s.active[req.Model]; ok {
|
||||
s.logger.Debugf("%s: queued request for model %s now joining in-flight swap", s.name, req.Model)
|
||||
sw.waiters = append(sw.waiters, req)
|
||||
continue
|
||||
}
|
||||
running := s.runningSet(req.Model)
|
||||
evict := s.planner.EvictionFor(req.Model, running)
|
||||
if state == process.StateReady && len(evict) == 0 && !collidesWith(req.Model, evict, s.active) {
|
||||
s.logger.Debugf("%s: queued request for model %s now served fast-path", s.name, req.Model)
|
||||
s.grantHandler(req, req.Model)
|
||||
continue
|
||||
}
|
||||
if collidesWith(req.Model, evict, s.active) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
if conflictsWithInFlight(evict, s.inFlight) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
s.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", s.name, req.Model, evict)
|
||||
s.startSwap(req, evict, running)
|
||||
}
|
||||
s.queued = remaining
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
// runningSet is the live model set handed to the Swapper: every process the
|
||||
// baseRouter reports as running, unioned with the targets of in-flight swaps
|
||||
// (excluding excludeActive, the model whose own swap is being decided — its
|
||||
// in-flight entry must not count as "already running"). The result is sorted so
|
||||
// eviction decisions derived from it are deterministic.
|
||||
func (s *FIFO) runningSet(excludeActive string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
add := func(id string) {
|
||||
if _, dup := seen[id]; dup {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
for id := range s.effects.RunningModels() {
|
||||
add(id)
|
||||
}
|
||||
for _, id := range activeTargets(s.active, excludeActive) {
|
||||
add(id)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||
// The planner uses this to account for models committed to but not yet reflected
|
||||
// in process state.
|
||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||
if len(active) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(active))
|
||||
for id := range active {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// collidesWith reports whether a new swap with this target and evict set can
|
||||
// safely run alongside the currently active swaps. Same-target callers should
|
||||
// JOIN (handled before this) — they do not collide with themselves.
|
||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||
for id, sw := range active {
|
||||
if id == target {
|
||||
continue
|
||||
}
|
||||
if containsString(evict, id) {
|
||||
return true
|
||||
}
|
||||
if containsString(sw.evict, target) {
|
||||
return true
|
||||
}
|
||||
if slicesOverlap(evict, sw.evict) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// slicesOverlap reports whether xs and ys share any common element.
|
||||
func slicesOverlap(xs, ys []string) bool {
|
||||
for _, x := range xs {
|
||||
if containsString(ys, x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
||||
// requests. Stopping a busy process would cancel its callers' connections, so
|
||||
// the scheduler defers the swap until those callers finish.
|
||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
||||
for _, m := range evict {
|
||||
if inFlight[m] > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(xs []string, s string) bool {
|
||||
for _, x := range xs {
|
||||
if x == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||
// drained first so the consumer always sees the latest position.
|
||||
func broadcastQueuePositions(queued []HandlerReq) {
|
||||
for i, req := range queued {
|
||||
pos := i + 1
|
||||
select {
|
||||
case req.PositionCh <- pos:
|
||||
default:
|
||||
select {
|
||||
case <-req.PositionCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case req.PositionCh <- pos:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,633 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// FIFO methods all run on the router's single run-loop goroutine, so these
|
||||
// tests drive them directly and synchronously. A swap is "completed" by calling
|
||||
// OnSwapDone, a served request "finishes" by calling OnServeDone — exactly the
|
||||
// events the run loop would deliver. fakeEffects records every side-effect and
|
||||
// stubPlanner supplies a fixed eviction set per target.
|
||||
|
||||
// stubPlanner returns a fixed eviction list per target.
|
||||
type stubPlanner struct {
|
||||
evict map[string][]string
|
||||
}
|
||||
|
||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||
if s.evict == nil {
|
||||
return nil
|
||||
}
|
||||
return s.evict[target]
|
||||
}
|
||||
|
||||
func (s *stubPlanner) OnSwapStart(string, []string) {}
|
||||
|
||||
// grantRec is one GrantError / GrantServe call. err!=nil marks an error grant;
|
||||
// otherwise it is a serve grant and serve reports whether the caller received it.
|
||||
type grantRec struct {
|
||||
model string
|
||||
err error
|
||||
serve bool
|
||||
}
|
||||
|
||||
type startRec struct {
|
||||
model string
|
||||
evict []string
|
||||
}
|
||||
|
||||
type stopRec struct {
|
||||
timeout time.Duration
|
||||
ids []string
|
||||
}
|
||||
|
||||
// fakeEffects is an in-memory scheduler.Effects. Tests program process states
|
||||
// and GrantServe outcomes, then assert on the recorded calls.
|
||||
type fakeEffects struct {
|
||||
states map[string]process.ProcessState // model -> state; missing => not handled
|
||||
serveResult map[string]bool // GrantServe return per model (default true)
|
||||
|
||||
starts []startRec
|
||||
grants []grantRec
|
||||
stops []stopRec
|
||||
}
|
||||
|
||||
func newFakeEffects() *fakeEffects {
|
||||
return &fakeEffects{
|
||||
states: map[string]process.ProcessState{},
|
||||
serveResult: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeEffects) ModelState(modelID string) (process.ProcessState, bool) {
|
||||
st, ok := f.states[modelID]
|
||||
return st, ok
|
||||
}
|
||||
|
||||
func (f *fakeEffects) RunningModels() map[string]process.ProcessState {
|
||||
out := make(map[string]process.ProcessState)
|
||||
for id, st := range f.states {
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
out[id] = st
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (f *fakeEffects) StartSwap(modelID string, evict []string) {
|
||||
f.starts = append(f.starts, startRec{model: modelID, evict: evict})
|
||||
}
|
||||
|
||||
func (f *fakeEffects) GrantError(req HandlerReq, err error) {
|
||||
f.grants = append(f.grants, grantRec{model: req.Model, err: err})
|
||||
}
|
||||
|
||||
func (f *fakeEffects) GrantServe(req HandlerReq, modelID string) bool {
|
||||
ok := true
|
||||
if v, set := f.serveResult[modelID]; set {
|
||||
ok = v
|
||||
}
|
||||
f.grants = append(f.grants, grantRec{model: modelID, serve: ok})
|
||||
return ok
|
||||
}
|
||||
|
||||
func (f *fakeEffects) StopProcesses(timeout time.Duration, ids []string) {
|
||||
f.stops = append(f.stops, stopRec{timeout: timeout, ids: ids})
|
||||
}
|
||||
|
||||
// served counts grants that handed modelID a handler and were received.
|
||||
func (f *fakeEffects) served(modelID string) int {
|
||||
n := 0
|
||||
for _, g := range f.grants {
|
||||
if g.err == nil && g.serve && g.model == modelID {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// errored counts error grants, optionally filtered by model ("" = any).
|
||||
func (f *fakeEffects) errored(model string) int {
|
||||
n := 0
|
||||
for _, g := range f.grants {
|
||||
if g.err != nil && (model == "" || g.model == model) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// startsFor counts StartSwap calls for modelID.
|
||||
func (f *fakeEffects) startsFor(modelID string) int {
|
||||
n := 0
|
||||
for _, s := range f.starts {
|
||||
if s.model == modelID {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func newFIFO(planner Swapper, eff Effects) *FIFO {
|
||||
return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, eff)
|
||||
}
|
||||
|
||||
func req(model string) HandlerReq { return HandlerReq{Model: model} }
|
||||
|
||||
// reqCh creates a HandlerReq with a unique Respond channel so OnCancel can
|
||||
// identify it among queued requests and swap waiters.
|
||||
func reqCh(model string) HandlerReq {
|
||||
return HandlerReq{
|
||||
Model: model,
|
||||
Respond: make(chan HandlerResp, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_FastPath(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := eff.startsFor("a"); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0 (fast path should not swap)", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_ModelNotFound(t *testing.T) {
|
||||
eff := newFakeEffects() // no states => model unknown
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("ghost"))
|
||||
|
||||
if got := len(eff.starts); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0", got)
|
||||
}
|
||||
if eff.errored("ghost") != 1 {
|
||||
t.Fatalf("want 1 error grant for ghost, grants=%+v", eff.grants)
|
||||
}
|
||||
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnDemandStartThenServe(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 before swap completes", got)
|
||||
}
|
||||
|
||||
// Swap finishes, model is now ready.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 after swap done", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_JoinInFlightSwap(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // starts swap
|
||||
s.OnRequest(req("a")) // joins
|
||||
s.OnRequest(req("a")) // joins
|
||||
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1 (all three share one swap)", got)
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 3 {
|
||||
t.Errorf("served(a)=%d want 3 (one swap serves all waiters)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_SwapDoneError_FailsAllWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||
|
||||
if eff.served("a") != 0 {
|
||||
t.Errorf("served(a)=%d want 0 on swap error", eff.served("a"))
|
||||
}
|
||||
if eff.errored("a") != 2 {
|
||||
t.Errorf("errored(a)=%d want 2 (both waiters fail)", eff.errored("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueOnEvictionCollision covers a request whose target evicts the
|
||||
// model currently being swapped: it must queue until that swap finishes AND its
|
||||
// served request drains, because starting it would stop a busy process.
|
||||
func TestFIFO_QueueOnEvictionCollision(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// Loading b evicts a.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("b")) // collides with a's in-flight swap -> queue
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started early: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a becomes ready and is granted (now serving, inFlight[a]=1).
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started while a is serving: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a's request finishes -> a no longer in-flight -> b may now swap.
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a drained", got)
|
||||
}
|
||||
if got := eff.starts[len(eff.starts)-1].evict; len(got) != 1 || got[0] != "a" {
|
||||
t.Errorf("b swap evict=%v want [a]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_DisjointSwapsRunInParallel verifies two requests with
|
||||
// non-conflicting evict sets both start without waiting for each other.
|
||||
func TestFIFO_DisjointSwapsRunInParallel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff) // empty evicts
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b"))
|
||||
|
||||
if eff.startsFor("a") != 1 || eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap a=%d b=%d want 1 each (parallel)", eff.startsFor("a"), eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OverlappingEvictSetsDoNotRunInParallel verifies two swaps with
|
||||
// different targets that evict the *same* model do not run concurrently: the
|
||||
// second must queue rather than double-evict the shared model. Neither target is
|
||||
// in the other's evict set, so this is only caught by the evict-set overlap
|
||||
// check in collidesWith.
|
||||
func TestFIFO_OverlappingEvictSetsDoNotRunInParallel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["x"] = process.StateReady // shared eviction target, running
|
||||
// Loading a or b both require evicting x.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"x"}, "b": {"x"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a, [x])
|
||||
s.OnRequest(req("b")) // overlaps a's evict set ([x]) -> queue
|
||||
if eff.startsFor("a") != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", eff.startsFor("a"))
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("b started in parallel while a evicts x: StartSwap(b)=%d want 0", got)
|
||||
}
|
||||
|
||||
// a's swap completes and x is gone; b can now evict nothing and start.
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["x"] = process.StateStopped
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a's swap drained", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueDrainPromotesMultiple verifies completing one swap unblocks
|
||||
// every queued request that no longer collides — they all start together.
|
||||
func TestFIFO_QueueDrainPromotesMultiple(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["c"] = process.StateStopped
|
||||
// a's swap evicts both b and c; b and c evict nothing.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"a": {"b", "c"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a, [b,c])
|
||||
s.OnRequest(req("b")) // collides (in a's evict set) -> queue
|
||||
s.OnRequest(req("c")) // collides -> queue
|
||||
if eff.startsFor("b") != 0 || eff.startsFor("c") != 0 {
|
||||
t.Fatalf("b/c started early")
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
// b and c have empty evict sets and don't evict a, so both start now.
|
||||
if eff.startsFor("b") != 1 || eff.startsFor("c") != 1 {
|
||||
t.Fatalf("StartSwap b=%d c=%d want 1 each after a done", eff.startsFor("b"), eff.startsFor("c"))
|
||||
}
|
||||
if eff.served("a") != 1 {
|
||||
t.Errorf("served(a)=%d want 1", eff.served("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_QueueCollation verifies duplicate requests collapse into one swap
|
||||
// per model: the second request for each model joins the active swap (at arrival
|
||||
// or at drain time) rather than triggering its own swap.
|
||||
func TestFIFO_QueueCollation(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
eff.states[id] = process.StateStopped
|
||||
}
|
||||
// Each model evicts the other two: all swaps are mutually exclusive.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
"b": {"a", "c"},
|
||||
"c": {"a", "b"},
|
||||
}}, eff)
|
||||
|
||||
for _, id := range []string{"a", "b", "c", "a", "b", "c"} {
|
||||
s.OnRequest(req(id))
|
||||
}
|
||||
|
||||
// Drain a, then its served requests, which promotes b; repeat for b -> c.
|
||||
drain := func(model string, waiters int) {
|
||||
eff.states[model] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: model})
|
||||
for i := 0; i < waiters; i++ {
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||
}
|
||||
}
|
||||
drain("a", 2)
|
||||
drain("b", 2)
|
||||
drain("c", 2)
|
||||
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
if got := eff.startsFor(id); got != 1 {
|
||||
t.Errorf("StartSwap(%s)=%d want 1 (collation)", id, got)
|
||||
}
|
||||
if got := eff.served(id); got != 2 {
|
||||
t.Errorf("served(%s)=%d want 2", id, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_NoSwapWhileServing verifies a model still handling requests is not
|
||||
// evicted: the evicting request waits until every in-flight request drains.
|
||||
func TestFIFO_NoSwapWhileServing(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // fast path, inFlight[a]=1
|
||||
s.OnRequest(req("a")) // fast path, inFlight[a]=2
|
||||
s.OnRequest(req("b")) // would evict busy a -> queue
|
||||
if eff.startsFor("b") != 0 {
|
||||
t.Fatalf("b started while a serving")
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=1
|
||||
if eff.startsFor("b") != 0 {
|
||||
t.Fatalf("b started while a still serving one request")
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // inFlight[a]=0
|
||||
if eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a fully drained", eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_GrantServeFalseDoesNotLeakInFlight verifies that when a caller has
|
||||
// walked away (GrantServe returns false) the in-flight count is not bumped, so a
|
||||
// later evicting request is not blocked forever.
|
||||
func TestFIFO_GrantServeFalseDoesNotLeakInFlight(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.serveResult["a"] = false // a's waiter is gone by grant time
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails, inFlight[a] stays 0
|
||||
|
||||
// b evicts a; since a is not in-flight, b should start immediately.
|
||||
s.OnRequest(req("b"))
|
||||
if eff.startsFor("b") != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 (no leaked in-flight on a)", eff.startsFor("b"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnShutdown_FailsAllWaiters verifies shutdown errors every waiter the
|
||||
// scheduler holds: active-swap waiters and queued requests alike.
|
||||
func TestFIFO_OnShutdown_FailsAllWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, id := range []string{"a", "b", "c"} {
|
||||
eff.states[id] = process.StateStopped
|
||||
}
|
||||
// a and b load in parallel; c collides with both and queues.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"c": {"a", "b"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("a")) // join a
|
||||
s.OnRequest(req("b")) // StartSwap(b)
|
||||
s.OnRequest(req("b")) // join b
|
||||
s.OnRequest(req("c")) // queued
|
||||
|
||||
s.OnShutdown(errors.New("shutting down"))
|
||||
|
||||
if got := eff.errored(""); got != 5 {
|
||||
t.Errorf("error grants=%d want 5 (2 a + 2 b + 1 c)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnUnload_ReleasesActiveWaiters(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // active swap a with one waiter
|
||||
s.OnRequest(req("a")) // join
|
||||
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
|
||||
if got := eff.errored("a"); got != 2 {
|
||||
t.Errorf("errored(a)=%d want 2 (active swap waiters released)", got)
|
||||
}
|
||||
if len(eff.stops) != 1 || len(eff.stops[0].ids) != 1 || eff.stops[0].ids[0] != "a" {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
if eff.stops[0].timeout != time.Second {
|
||||
t.Errorf("StopProcesses timeout=%v want 1s", eff.stops[0].timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIFO_OnUnload_DropsQueuedRequests(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// b evicts a, so a request for b queues while a is loading.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
s.OnUnload([]string{"b"}, time.Second)
|
||||
|
||||
if got := eff.errored("b"); got != 1 {
|
||||
t.Errorf("errored(b)=%d want 1 (queued request dropped)", got)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (b should never start)", got)
|
||||
}
|
||||
// a's swap is untouched: its waiter is neither served nor errored yet.
|
||||
if eff.served("a") != 0 || eff.errored("a") != 0 {
|
||||
t.Errorf("a swap should be untouched: served=%d errored=%d", eff.served("a"), eff.errored("a"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_PriorityQueueOrder verifies queued requests are ordered by descending
|
||||
// priority, with arrival (FIFO) order preserved among equal-priority models.
|
||||
func TestFIFO_PriorityQueueOrder(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, m := range []string{"z", "A", "B", "C", "D"} {
|
||||
eff.states[m] = process.StateStopped
|
||||
}
|
||||
// z's swap evicts every other model, so any request that arrives while z is
|
||||
// loading collides with z's in-flight swap and parks in the queue.
|
||||
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}}
|
||||
cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}}
|
||||
s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, eff)
|
||||
|
||||
s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D])
|
||||
|
||||
// Arrive out of priority order; B before C exercises FIFO tie-breaking.
|
||||
for _, m := range []string{"B", "D", "C", "A"} {
|
||||
s.OnRequest(req(m))
|
||||
}
|
||||
|
||||
got := make([]string, len(s.queued))
|
||||
for i, q := range s.queued {
|
||||
got[i] = q.Model
|
||||
}
|
||||
want := []string{"A", "B", "C", "D"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("queue=%v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("queue=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_QueuedRequest verifies that cancelling a queued request
|
||||
// prevents drainQueue from ever starting a model load for it. Without OnCancel
|
||||
// the dead request would sit in the queue until a drain triggers a wasted swap.
|
||||
func TestFIFO_OnCancel_QueuedRequest(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
// b evicts a, so a request for b queues while a is loading.
|
||||
s := newFIFO(&stubPlanner{evict: map[string][]string{"b": {"a"}}}, eff)
|
||||
|
||||
s.OnRequest(req("a")) // StartSwap(a)
|
||||
|
||||
cancelledReq := reqCh("b")
|
||||
s.OnRequest(cancelledReq) // queued (collides with a's in-flight swap)
|
||||
if len(s.queued) != 1 {
|
||||
t.Fatalf("queue len=%d want 1 before cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// Client disconnects.
|
||||
s.OnCancel(cancelledReq)
|
||||
|
||||
if len(s.queued) != 0 {
|
||||
t.Fatalf("queue len=%d want 0 after cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// a's swap finishes; drainQueue runs but b is gone — no swap for b.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (cancelled request should not trigger a load)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_SwapWaiter verifies that cancelling a request that joined an
|
||||
// in-flight swap removes it from the waiter list. When the swap completes, the
|
||||
// cancelled waiter receives no grant and does not bump the in-flight count.
|
||||
func TestFIFO_OnCancel_SwapWaiter(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
liveReq := reqCh("a")
|
||||
cancelledReq := reqCh("a")
|
||||
s.OnRequest(liveReq) // starts swap
|
||||
s.OnRequest(cancelledReq) // joins
|
||||
|
||||
if sw := s.active["a"]; len(sw.waiters) != 2 {
|
||||
t.Fatalf("waiters=%d want 2", len(sw.waiters))
|
||||
}
|
||||
|
||||
s.OnCancel(cancelledReq)
|
||||
|
||||
if sw := s.active["a"]; len(sw.waiters) != 1 {
|
||||
t.Fatalf("waiters=%d want 1 after cancel", len(sw.waiters))
|
||||
}
|
||||
|
||||
// Swap finishes: only the live waiter is granted.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 (only the non-cancelled waiter)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFIFO_OnCancel_NotPresent is a no-op: cancelling a request that was already
|
||||
// granted (and is no longer queued or waiting) must not affect anything.
|
||||
func TestFIFO_OnCancel_NotPresent(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newFIFO(&stubPlanner{}, eff)
|
||||
|
||||
r := reqCh("a")
|
||||
s.OnRequest(r) // fast-path served immediately
|
||||
|
||||
// Cancel after grant — should be a harmless no-op.
|
||||
s.OnCancel(r)
|
||||
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 (cancel of granted request is a no-op)", got)
|
||||
}
|
||||
if len(s.queued) != 0 {
|
||||
t.Errorf("queue should be empty, len=%d", len(s.queued))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
// Package scheduler contains the request-scheduling strategies used by the
|
||||
// router's baseRouter. A Scheduler owns the queue, in-flight tracking, and the
|
||||
// decision tree for when to start a swap versus queue a request. The baseRouter
|
||||
// owns the channels, run loop, and process machinery, and exposes the
|
||||
// side-effects a scheduler needs through the Effects interface.
|
||||
//
|
||||
// Splitting these apart lets the scheduling strategy be swapped out
|
||||
// independently of both the process machinery (baseRouter) and the eviction
|
||||
// policy (Swapper). FIFO is the first and currently only implementation.
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// ErrModelNotFound is granted to callers whose model is not handled by this
|
||||
// router. It is an alias for shared.ErrNoLocalModelFound.
|
||||
var ErrModelNotFound = shared.ErrNoLocalModelFound
|
||||
|
||||
// Swapper is the eviction policy: it decides which running models must be
|
||||
// stopped before a target can serve. It is orthogonal to the scheduling
|
||||
// strategy — any Scheduler works with any Swapper.
|
||||
type Swapper interface {
|
||||
// EvictionFor returns running model IDs that must be stopped before
|
||||
// target can serve. running is the complete set the scheduler considers
|
||||
// live: every process that is not stopped, unioned with the targets of
|
||||
// in-flight swaps the scheduler has already committed to (which are not yet
|
||||
// visible in process state). The planner does not inspect process state
|
||||
// itself. Pure decision; must not log.
|
||||
EvictionFor(target string, running []string) []string
|
||||
|
||||
// OnSwapStart runs once at the start of every swap, with the same running
|
||||
// set EvictionFor was given for this decision. Planners may log their
|
||||
// decision here at whatever verbosity they choose.
|
||||
OnSwapStart(target string, running []string)
|
||||
}
|
||||
|
||||
// Scheduler decides what happens to each event the router's run loop receives.
|
||||
// All methods run on that single run-loop goroutine, so implementations need no
|
||||
// internal locking for their own state.
|
||||
type Scheduler interface {
|
||||
// OnRequest handles one incoming ServeHTTP request.
|
||||
OnRequest(req HandlerReq)
|
||||
// OnCancel handles a request whose client has disconnected before it was
|
||||
// granted. The scheduler must remove the request from its queue and from
|
||||
// any in-flight swap's waiters so it never triggers a model load or grant
|
||||
// for a caller that is no longer there.
|
||||
OnCancel(req HandlerReq)
|
||||
// OnSwapDone handles a swap goroutine reporting completion.
|
||||
OnSwapDone(ev SwapDone)
|
||||
// OnServeDone handles a tracked ServeHTTP finishing (in-flight decrement).
|
||||
OnServeDone(ev ServeDoneEvent)
|
||||
// OnUnload reconciles scheduler state for an unload, stops the targeted
|
||||
// processes via Effects, and drains the queue. It must block until the
|
||||
// targeted processes have stopped.
|
||||
OnUnload(targets []string, timeout time.Duration)
|
||||
// OnShutdown grants err to every waiter the scheduler still holds (active
|
||||
// swap waiters and queued requests). Process teardown is the baseRouter's
|
||||
// responsibility.
|
||||
OnShutdown(err error)
|
||||
}
|
||||
|
||||
// Effects is implemented by the baseRouter. The scheduler calls back through it
|
||||
// for every side-effect: inspecting process state, launching swaps, responding
|
||||
// to callers, and stopping processes.
|
||||
type Effects interface {
|
||||
// ModelState returns the current state of a model's process. ok is false
|
||||
// when the model is not handled by this router.
|
||||
ModelState(modelID string) (process.ProcessState, bool)
|
||||
// RunningModels returns the state of every process that is not stopped or
|
||||
// shut down, keyed by model ID. The scheduler uses it to build the running
|
||||
// set it hands the Swapper.
|
||||
RunningModels() map[string]process.ProcessState
|
||||
// StartSwap launches the swap goroutine for modelID, stopping evict first.
|
||||
StartSwap(modelID string, evict []string)
|
||||
// GrantError responds to a caller with an error.
|
||||
GrantError(req HandlerReq, err error)
|
||||
// GrantServe hands a caller the wrapped handler for modelID and reports
|
||||
// whether the caller was still there to receive it. The scheduler bumps
|
||||
// its in-flight count only when this returns true.
|
||||
GrantServe(req HandlerReq, modelID string) bool
|
||||
// StopProcesses stops the named processes in parallel and blocks until all
|
||||
// have stopped. Unknown IDs are skipped.
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
|
||||
// Factory builds a Scheduler bound to a baseRouter's Effects. The concrete
|
||||
// router captures its Swapper in the closure it passes as a Factory.
|
||||
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler
|
||||
|
||||
// HandlerReq is one in-flight ServeHTTP request waiting for a routing decision.
|
||||
type HandlerReq struct {
|
||||
Model string
|
||||
Ctx context.Context
|
||||
Respond chan HandlerResp
|
||||
PositionCh chan int
|
||||
}
|
||||
|
||||
// HandlerResp is the routing decision returned to a HandlerReq's caller: either
|
||||
// a handler to serve with, or an error.
|
||||
type HandlerResp struct {
|
||||
HandleFunc http.HandlerFunc
|
||||
Err error
|
||||
}
|
||||
|
||||
// SwapDone is reported by a swap goroutine when its target is ready (or failed).
|
||||
type SwapDone struct {
|
||||
ModelID string
|
||||
Err error
|
||||
}
|
||||
|
||||
// ServeDoneEvent is reported when a tracked ServeHTTP handler returns.
|
||||
type ServeDoneEvent struct {
|
||||
ModelID string
|
||||
}
|
||||
+128
-17
@@ -9,19 +9,126 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiUnloadTimeout is used by the API endpoints to stop processes
|
||||
const apiUnloadTimeout = 10 * time.Second
|
||||
|
||||
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
||||
type modelRecord struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Architecture map[string]any `json:"architecture,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||
ContextLength int `json:"context_length,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// cappedMetadataKeys are top-level /v1/models fields produced by the
|
||||
// capabilities renderer. If a model's metadata block defines any of these
|
||||
// keys, the renderer's values win and the metadata keys are dropped.
|
||||
var cappedMetadataKeys = map[string]struct{}{
|
||||
"architecture": {},
|
||||
"capabilities": {},
|
||||
"supported_parameters": {},
|
||||
"context_length": {},
|
||||
}
|
||||
|
||||
// renderCapabilities converts a model's capabilities config into additional
|
||||
// /v1/models fields. Returns zero values when caps.Empty() is true.
|
||||
func renderCapabilities(caps config.ModelCapConfig) (arch map[string]any, capsMap map[string]any, params []string, ctxLen int) {
|
||||
if caps.Empty() {
|
||||
return
|
||||
}
|
||||
|
||||
hasIn := len(caps.In) > 0
|
||||
hasOut := len(caps.Out) > 0
|
||||
|
||||
if hasIn || hasOut {
|
||||
arch = make(map[string]any)
|
||||
}
|
||||
if hasIn {
|
||||
arch["input_modalities"] = caps.In
|
||||
}
|
||||
if hasOut {
|
||||
arch["output_modalities"] = caps.Out
|
||||
}
|
||||
if hasIn && hasOut {
|
||||
arch["modality"] = strings.Join(caps.In, "+") + "->" + strings.Join(caps.Out, "+")
|
||||
}
|
||||
|
||||
// Build capabilities map only if there's something to put in it.
|
||||
if hasIn || hasOut || caps.Tools || caps.Reranker {
|
||||
capsMap = make(map[string]any)
|
||||
}
|
||||
|
||||
if hasIn {
|
||||
if contains(caps.In, "image") {
|
||||
capsMap["vision"] = true
|
||||
}
|
||||
}
|
||||
if hasIn && hasOut {
|
||||
if contains(caps.In, "audio") && contains(caps.Out, "text") {
|
||||
capsMap["audio_transcriptions"] = true
|
||||
}
|
||||
if contains(caps.In, "text") && contains(caps.Out, "audio") {
|
||||
capsMap["audio_speech"] = true
|
||||
}
|
||||
if contains(caps.In, "text") && contains(caps.Out, "image") {
|
||||
capsMap["image_generation"] = true
|
||||
}
|
||||
if contains(caps.In, "image") && contains(caps.Out, "image") {
|
||||
capsMap["image_to_image"] = true
|
||||
}
|
||||
}
|
||||
|
||||
if caps.Tools {
|
||||
capsMap["function_calling"] = true
|
||||
params = []string{"tools", "tool_choice"}
|
||||
}
|
||||
|
||||
if caps.Reranker {
|
||||
capsMap["reranker"] = true
|
||||
}
|
||||
|
||||
if caps.Context > 0 {
|
||||
ctxLen = caps.Context
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// contains reports whether s is present in ss.
|
||||
func contains(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterCappedMetadata returns metadata with renderer-owned keys removed.
|
||||
func filterCappedMetadata(md map[string]any) map[string]any {
|
||||
if len(md) == 0 {
|
||||
return nil
|
||||
}
|
||||
filtered := make(map[string]any, len(md))
|
||||
for k, v := range md {
|
||||
if _, capped := cappedMetadataKeys[k]; !capped {
|
||||
filtered[k] = v
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||
@@ -30,7 +137,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
created := time.Now().Unix()
|
||||
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||
|
||||
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
|
||||
newRecord := func(id, name, description string, metadata map[string]any, caps config.ModelCapConfig) modelRecord {
|
||||
rec := modelRecord{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
@@ -39,6 +146,10 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
Name: strings.TrimSpace(name),
|
||||
Description: strings.TrimSpace(description),
|
||||
}
|
||||
rec.Architecture, rec.Capabilities, rec.SupportedParameters, rec.ContextLength = renderCapabilities(caps)
|
||||
if !caps.Empty() {
|
||||
metadata = filterCappedMetadata(metadata)
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||
}
|
||||
@@ -49,12 +160,12 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
if mc.Unlisted {
|
||||
continue
|
||||
}
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||
|
||||
if s.cfg.IncludeAliasesInList {
|
||||
for _, alias := range mc.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,7 +173,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}, config.ModelCapConfig{}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +205,7 @@ type runningModel struct {
|
||||
// handleUnload stops every running local process. Peer models are remote and
|
||||
// unaffected.
|
||||
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
s.local.Unload(apiUnloadTimeout)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
@@ -160,7 +271,7 @@ func (s *Server) startPreload() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||
|
||||
dw := &discardResponseWriter{status: http.StatusOK}
|
||||
s.local.ServeHTTP(dw, req)
|
||||
@@ -205,7 +316,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,7 +338,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the /upstream/<model> prefix before forwarding.
|
||||
r.URL.Path = remainingPath
|
||||
// Pin the resolved model so the router skips body/query extraction.
|
||||
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||
|
||||
switch {
|
||||
case s.local.Handles(modelID):
|
||||
@@ -235,7 +346,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
case s.peer.Handles(modelID):
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -157,3 +157,262 @@ func TestServer_Redirects(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleListModels_Capabilities(t *testing.T) {
|
||||
newServer := func(mc config.ModelConfig) *Server {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m": mc}}
|
||||
return s
|
||||
}
|
||||
getModel := func(t *testing.T, s *Server) modelRecord {
|
||||
t.Helper()
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
|
||||
var resp struct {
|
||||
Data []modelRecord `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(resp.Data))
|
||||
}
|
||||
return resp.Data[0]
|
||||
}
|
||||
|
||||
t.Run("all_fields", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{
|
||||
In: []string{"text", "image"},
|
||||
Out: []string{"text", "audio"},
|
||||
Tools: true,
|
||||
Context: 100000,
|
||||
},
|
||||
}))
|
||||
if m.Architecture == nil {
|
||||
t.Fatal("architecture is nil")
|
||||
}
|
||||
if !anySliceStrEqual(m.Architecture["input_modalities"], []string{"text", "image"}) {
|
||||
t.Errorf("input_modalities = %v", m.Architecture["input_modalities"])
|
||||
}
|
||||
if !anySliceStrEqual(m.Architecture["output_modalities"], []string{"text", "audio"}) {
|
||||
t.Errorf("output_modalities = %v", m.Architecture["output_modalities"])
|
||||
}
|
||||
if m.Architecture["modality"] != "text+image->text+audio" {
|
||||
t.Errorf("modality = %v", m.Architecture["modality"])
|
||||
}
|
||||
if m.Capabilities == nil || m.Capabilities["vision"] != true {
|
||||
t.Errorf("vision = %v", m.Capabilities)
|
||||
}
|
||||
if m.Capabilities["audio_speech"] != true {
|
||||
t.Errorf("audio_speech = %v", m.Capabilities["audio_speech"])
|
||||
}
|
||||
if m.Capabilities["function_calling"] != true {
|
||||
t.Errorf("function_calling = %v", m.Capabilities["function_calling"])
|
||||
}
|
||||
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
|
||||
t.Errorf("supported_parameters = %v", m.SupportedParameters)
|
||||
}
|
||||
if m.ContextLength != 100000 {
|
||||
t.Errorf("context_length = %d", m.ContextLength)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("in_only", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"text", "image"}},
|
||||
}))
|
||||
if m.Architecture == nil {
|
||||
t.Fatal("architecture is nil")
|
||||
}
|
||||
if _, ok := m.Architecture["output_modalities"]; ok {
|
||||
t.Error("should not have output_modalities")
|
||||
}
|
||||
if _, ok := m.Architecture["modality"]; ok {
|
||||
t.Error("should not have modality")
|
||||
}
|
||||
if m.Capabilities == nil || m.Capabilities["vision"] != true {
|
||||
t.Error("expected vision: true")
|
||||
}
|
||||
if m.SupportedParameters != nil {
|
||||
t.Error("should not have supported_parameters")
|
||||
}
|
||||
if m.ContextLength != 0 {
|
||||
t.Error("should not have context_length")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out_only", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Out: []string{"audio"}},
|
||||
}))
|
||||
if m.Architecture == nil {
|
||||
t.Fatal("architecture is nil")
|
||||
}
|
||||
if _, ok := m.Architecture["input_modalities"]; ok {
|
||||
t.Error("should not have input_modalities")
|
||||
}
|
||||
if len(m.Capabilities) > 0 {
|
||||
t.Errorf("expected no capabilities, got %v", m.Capabilities)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tools", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Tools: true},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["function_calling"] != true {
|
||||
t.Error("expected function_calling: true")
|
||||
}
|
||||
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
|
||||
t.Errorf("supported_parameters = %v", m.SupportedParameters)
|
||||
}
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reranker", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Reranker: true},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["reranker"] != true {
|
||||
t.Error("expected reranker: true")
|
||||
}
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{Context: 32768},
|
||||
}))
|
||||
if m.ContextLength != 32768 {
|
||||
t.Errorf("context_length = %d", m.ContextLength)
|
||||
}
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("audio_transcriptions", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"audio"}, Out: []string{"text"}},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["audio_transcriptions"] != true {
|
||||
t.Error("expected audio_transcriptions: true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("image_generation", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"text"}, Out: []string{"image"}},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["image_generation"] != true {
|
||||
t.Error("expected image_generation: true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("image_to_image", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"image"}, Out: []string{"image"}},
|
||||
}))
|
||||
if m.Capabilities == nil || m.Capabilities["image_to_image"] != true {
|
||||
t.Error("expected image_to_image: true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_skip", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{}))
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture")
|
||||
}
|
||||
if m.Capabilities != nil {
|
||||
t.Error("should not have capabilities")
|
||||
}
|
||||
if m.SupportedParameters != nil {
|
||||
t.Error("should not have supported_parameters")
|
||||
}
|
||||
if m.ContextLength != 0 {
|
||||
t.Error("should not have context_length")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("metadata_precedence", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Capabilities: config.ModelCapConfig{In: []string{"text"}},
|
||||
Metadata: map[string]any{
|
||||
"architecture": "should-be-dropped",
|
||||
"custom_field": "should-remain",
|
||||
"capabilities": "also-dropped",
|
||||
"other_metadata": "also-remain",
|
||||
},
|
||||
}))
|
||||
if m.Architecture == nil || m.Architecture["input_modalities"] == nil {
|
||||
t.Fatal("architecture should be rendered, not from metadata")
|
||||
}
|
||||
if m.Meta == nil || m.Meta["llamaswap"] == nil {
|
||||
t.Fatal("meta.llamaswap should exist")
|
||||
}
|
||||
meta := m.Meta["llamaswap"].(map[string]any)
|
||||
if _, ok := meta["architecture"]; ok {
|
||||
t.Error("architecture should be filtered from metadata")
|
||||
}
|
||||
if _, ok := meta["custom_field"]; !ok {
|
||||
t.Error("custom_field should remain in metadata")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("metadata_passthrough_no_caps", func(t *testing.T) {
|
||||
m := getModel(t, newServer(config.ModelConfig{
|
||||
Metadata: map[string]any{
|
||||
"architecture": "preserved",
|
||||
"context_length": 4096,
|
||||
"capabilities": "preserved",
|
||||
"custom_field": "preserved",
|
||||
},
|
||||
}))
|
||||
if m.Architecture != nil {
|
||||
t.Error("should not have architecture when caps is empty")
|
||||
}
|
||||
if m.Meta == nil || m.Meta["llamaswap"] == nil {
|
||||
t.Fatal("meta.llamaswap should exist")
|
||||
}
|
||||
meta := m.Meta["llamaswap"].(map[string]any)
|
||||
if _, ok := meta["architecture"]; !ok {
|
||||
t.Error("architecture should be preserved in metadata when caps is empty")
|
||||
}
|
||||
if _, ok := meta["context_length"]; !ok {
|
||||
t.Error("context_length should be preserved in metadata when caps is empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func anySliceStrEqual(v any, want []string) bool {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if len(arr) != len(want) {
|
||||
return false
|
||||
}
|
||||
for i := range arr {
|
||||
if s, ok := arr[i].(string); !ok || s != want[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
+27
-25
@@ -12,19 +12,19 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiModel is one entry in the /api/events modelStatus payload.
|
||||
type apiModel struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
// modelStatus returns every configured model joined with its current process
|
||||
@@ -45,13 +45,15 @@ func (s *Server) modelStatus() []apiModel {
|
||||
if st, ok := running[id]; ok {
|
||||
state = string(st)
|
||||
}
|
||||
_, capsMap, _, _ := renderCapabilities(mc.Capabilities)
|
||||
models = append(models, apiModel{
|
||||
Id: id,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
State: state,
|
||||
Unlisted: mc.Unlisted,
|
||||
Aliases: mc.Aliases,
|
||||
Id: id,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
State: state,
|
||||
Unlisted: mc.Unlisted,
|
||||
Aliases: mc.Aliases,
|
||||
Capabilities: capsMap,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -66,7 +68,7 @@ func (s *Server) modelStatus() []apiModel {
|
||||
|
||||
// handleAPIUnloadAll stops every running local process.
|
||||
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
s.local.Unload(apiUnloadTimeout)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
|
||||
}
|
||||
@@ -76,14 +78,14 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||
realName, found := s.cfg.RealModelName(requested)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
if !s.local.Handles(realName) {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
return
|
||||
}
|
||||
s.local.Unload(0, realName)
|
||||
s.local.Unload(apiUnloadTimeout, realName)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
@@ -92,7 +94,7 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := s.metrics.getMetricsJSON()
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -103,7 +105,7 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||
shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +114,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||
after, err := time.Parse(time.RFC3339, afterStr)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||
return
|
||||
}
|
||||
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||
@@ -153,19 +155,19 @@ func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||
return
|
||||
}
|
||||
|
||||
capture := s.metrics.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||
shared.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -198,7 +200,7 @@ func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+17
-31
@@ -1,19 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||
// config declares any. It accepts the key via Authorization: Bearer,
|
||||
// Authorization: Basic (password field), or x-api-key. On success the auth
|
||||
// headers are stripped so they never leak to upstream. When no keys are
|
||||
// Authorization: Basic (password field), or x-api-key. When no keys are
|
||||
// configured the middleware is a pass-through.
|
||||
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
keys := cfg.RequiredAPIKeys
|
||||
@@ -22,7 +20,7 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provided := extractAPIKey(r)
|
||||
provided := shared.ExtractAPIKey(r)
|
||||
|
||||
valid := false
|
||||
for _, key := range keys {
|
||||
@@ -33,41 +31,29 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
}
|
||||
if !valid {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
shared.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Del("Authorization")
|
||||
r.Header.Del("x-api-key")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func extractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
// CreateRequestContextMiddleware returns middleware that extracts model and
|
||||
// auth info from the request into the context. Requests where no model can be
|
||||
// identified are rejected with a 404.
|
||||
func CreateRequestContextMiddleware(cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
_ = data
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,48 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := extractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
@@ -74,11 +40,42 @@ func TestServer_IsTokenChar(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RequestContextMiddleware(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"llama3": {},
|
||||
},
|
||||
}
|
||||
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := CreateRequestContextMiddleware(cfg)
|
||||
|
||||
t.Run("known model passes through", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"llama3"}`))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model returns 404", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_AuthMiddleware(t *testing.T) {
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
||||
t.Error("auth headers leaked to upstream")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||
@@ -32,9 +32,9 @@ func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -45,7 +45,9 @@ func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
||||
return
|
||||
}
|
||||
if !sem.TryAcquire(1) {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"Too many requests"}`))
|
||||
return
|
||||
}
|
||||
defer sem.Release(1)
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
func concurrencyTestReq(model string) *http.Request {
|
||||
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
||||
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
||||
return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model}))
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -75,6 +77,55 @@ func TestServer_BodyCopier_Flush(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// hijackRecorder is an httptest.ResponseRecorder that also implements
|
||||
// http.Hijacker, returning a pipe so Hijack forwarding can be exercised.
|
||||
type hijackRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (h *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil
|
||||
}
|
||||
|
||||
func TestServer_BodyCopier_Hijack(t *testing.T) {
|
||||
t.Run("forwards to underlying hijacker", func(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
bc := newBodyCopier(&hijackRecorder{httptest.NewRecorder(), server})
|
||||
conn, _, err := bc.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("Hijack: %v", err)
|
||||
}
|
||||
if conn != server {
|
||||
t.Errorf("Hijack returned unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("errors when underlying writer is not a hijacker", func(t *testing.T) {
|
||||
bc := newBodyCopier(httptest.NewRecorder())
|
||||
if _, _, err := bc.Hijack(); err == nil {
|
||||
t.Error("expected error hijacking a non-Hijacker ResponseWriter")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_BodyCopier_SkipsBufferingOnUpgrade(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
bc := newBodyCopier(rec)
|
||||
bc.WriteHeader(http.StatusSwitchingProtocols)
|
||||
bc.Write([]byte("websocket frame bytes"))
|
||||
|
||||
if bc.body.Len() != 0 {
|
||||
t.Errorf("upgrade body buffered = %q, want empty", bc.body.Bytes())
|
||||
}
|
||||
if got := rec.Body.String(); got != "websocket frame bytes" {
|
||||
t.Errorf("client body = %q, want %q", got, "websocket frame bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HeaderMapAndRedact(t *testing.T) {
|
||||
h := http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -34,9 +34,9 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||
return
|
||||
}
|
||||
|
||||
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,13 +97,13 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+15
-4
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||
@@ -101,13 +102,13 @@ func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger, err := s.getLogger(logMonitorID)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||
shared.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +151,8 @@ var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"
|
||||
|
||||
// statusRecorder wraps an http.ResponseWriter to capture the response status
|
||||
// code and the number of body bytes written, so the access log can report
|
||||
// them. Flush is forwarded so streaming handlers (SSE) still work.
|
||||
// them. Flush is forwarded so streaming handlers (SSE) still work, and Hijack
|
||||
// is forwarded so httputil.ReverseProxy can upgrade websocket connections.
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
@@ -174,6 +176,15 @@ func (sr *statusRecorder) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack forwards to the underlying ResponseWriter so httputil.ReverseProxy can
|
||||
// take over the connection for websocket upgrades.
|
||||
func (sr *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := sr.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
|
||||
}
|
||||
|
||||
// clientIP resolves the originating client address, preferring proxy headers
|
||||
// over the raw connection address.
|
||||
func clientIP(r *http.Request) string {
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
@@ -135,3 +140,103 @@ func TestServer_RequestLogMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_RequestLogMiddleware_WebSocketUpgrade verifies that the access-log
|
||||
// middleware (which wraps responses in statusRecorder) does not break websocket
|
||||
// upgrades proxied through httputil.ReverseProxy. ReverseProxy requires the
|
||||
// ResponseWriter to implement http.Hijacker to take over the connection; if
|
||||
// statusRecorder does not forward Hijack, the upgrade is refused with 502.
|
||||
func TestServer_RequestLogMiddleware_WebSocketUpgrade(t *testing.T) {
|
||||
// Upstream: complete the upgrade handshake then echo bytes back. This
|
||||
// stands in for an upstream that speaks websocket; ReverseProxy only cares
|
||||
// about the 101 response and then copies raw bytes both ways.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Errorf("upstream ResponseWriter is not an http.Hijacker")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Errorf("upstream hijack: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
brw.Flush()
|
||||
// Echo whatever the client sends.
|
||||
buf := make([]byte, 64)
|
||||
n, err := brw.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
brw.Write(buf[:n])
|
||||
brw.Flush()
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
upstreamURL, err := url.Parse(upstream.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse upstream URL: %v", err)
|
||||
}
|
||||
|
||||
// Front server: ReverseProxy wrapped in the access-log middleware, which is
|
||||
// the production statusRecorder-wrapped path.
|
||||
proxy := httputil.NewSingleHostReverseProxy(upstreamURL)
|
||||
mw := CreateRequestLogMiddleware(logmon.NewWriter(io.Discard))
|
||||
front := httptest.NewServer(mw(proxy))
|
||||
defer front.Close()
|
||||
|
||||
frontURL, err := url.Parse(front.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse front URL: %v", err)
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", frontURL.Host, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial front: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
req := "GET / HTTP/1.1\r\n" +
|
||||
"Host: " + frontURL.Host + "\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: websocket\r\n" +
|
||||
"\r\n"
|
||||
if _, err := conn.Write([]byte(req)); err != nil {
|
||||
t.Fatalf("write upgrade request: %v", err)
|
||||
}
|
||||
|
||||
br := bufio.NewReader(conn)
|
||||
statusLine, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "101") {
|
||||
t.Fatalf("websocket upgrade failed: status line = %q, want 101 Switching Protocols", strings.TrimSpace(statusLine))
|
||||
}
|
||||
|
||||
// Drain the rest of the response headers.
|
||||
for {
|
||||
line, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read headers: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Verify bytes flow through the hijacked connection.
|
||||
if _, err := conn.Write([]byte("ping")); err != nil {
|
||||
t.Fatalf("write payload: %v", err)
|
||||
}
|
||||
echo := make([]byte, 4)
|
||||
if _, err := io.ReadFull(br, echo); err != nil {
|
||||
t.Fatalf("read echo: %v", err)
|
||||
}
|
||||
if string(echo) != "ping" {
|
||||
t.Errorf("echo = %q, want %q", echo, "ping")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -427,6 +429,12 @@ func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
// On a protocol upgrade (e.g. websocket) the body is raw framed data, not a
|
||||
// metrics-parseable response, so write straight to the client without
|
||||
// buffering a copy we can't use.
|
||||
if w.status == http.StatusSwitchingProtocols {
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
@@ -446,5 +454,14 @@ func (w *responseBodyCopier) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack forwards to the underlying writer so httputil.ReverseProxy can take
|
||||
// over the connection for websocket upgrades.
|
||||
func (w *responseBodyCopier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Status() int { return w.status }
|
||||
func (w *responseBodyCopier) StartTime() time.Time { return w.start }
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||
@@ -23,9 +23,9 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
||||
|
||||
// Resolve the model now so downstream dispatch hits the context
|
||||
// fast path; FetchContext restores the request body.
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
data, err := shared.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+15
-20
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||
@@ -99,12 +100,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
||||
var local router.LocalRouter
|
||||
var err error
|
||||
|
||||
if cfg.Matrix != nil {
|
||||
switch cfg.Routing.Router.Use {
|
||||
case "matrix":
|
||||
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating matrix router: %w", err)
|
||||
}
|
||||
} else {
|
||||
default: // "group"
|
||||
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating group router: %w", err)
|
||||
@@ -137,13 +139,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
||||
}
|
||||
|
||||
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||
// router. The model is resolved once via router.FetchContext.
|
||||
// router. The model is resolved once via shared.FetchContext.
|
||||
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
stripVersionPrefix(r)
|
||||
|
||||
data, err := router.FetchContext(r, s.cfg)
|
||||
data, err := shared.FetchContext(r, s.cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -155,7 +157,7 @@ func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendError(w, r, router.ErrNoRouterFound)
|
||||
shared.SendError(w, r, router.ErrNoRouterFound)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,21 +172,14 @@ func stripVersionPrefix(r *http.Request) {
|
||||
// routes builds the mux, registers every route, and wraps the mux with the
|
||||
// global CORS middleware.
|
||||
func (s *Server) routes() {
|
||||
authMW := CreateAuthMiddleware(s.cfg)
|
||||
filterMW := CreateFilterMiddleware(s.cfg)
|
||||
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
||||
|
||||
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
||||
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
||||
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
||||
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
||||
// other's Content-Type. Both run before the metrics middleware so it buffers
|
||||
// the rewritten body.
|
||||
authMW := CreateAuthMiddleware(s.cfg)
|
||||
modelChain := chain.New(
|
||||
authMW,
|
||||
CreateRequestContextMiddleware(s.cfg),
|
||||
CreateConcurrencyMiddleware(s.cfg),
|
||||
filterMW,
|
||||
formFilterMW,
|
||||
CreateFilterMiddleware(s.cfg),
|
||||
CreateFormFilterMiddleware(s.cfg),
|
||||
CreateInflightMiddleware(s.inflight),
|
||||
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||
)
|
||||
@@ -215,11 +210,11 @@ func (s *Server) routes() {
|
||||
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||
|
||||
// Embedded UI.
|
||||
mux.HandleFunc("GET /ui/", s.handleUI)
|
||||
mux.Handle("GET /ui/", chain.New(authMW).ThenFunc(s.handleUI))
|
||||
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||
|
||||
// Prometheus metrics (no auth, matches the legacy endpoint).
|
||||
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
||||
// Prometheus metrics (wrapped by apiChain, matches the legacy endpoint).
|
||||
mux.Handle("GET /metrics", apiChain.ThenFunc(s.handleMetrics))
|
||||
|
||||
// Operations endpoints.
|
||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||
|
||||
@@ -84,10 +84,15 @@ func chatRequest(model string) *http.Request {
|
||||
|
||||
func TestServer_New_GroupConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
|
||||
cfg := config.Config{HealthCheckTimeout: 15}
|
||||
cfg.Routing.Router.Use = "group"
|
||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (group): %v", err)
|
||||
}
|
||||
if _, ok := s.local.(*router.Group); !ok {
|
||||
t.Fatalf("localRouter=%T want *router.Group", s.local)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
@@ -95,11 +100,16 @@ func TestServer_New_GroupConfig(t *testing.T) {
|
||||
|
||||
func TestServer_New_MatrixConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
|
||||
cfg := config.Config{HealthCheckTimeout: 15}
|
||||
cfg.Routing.Router.Use = "matrix"
|
||||
cfg.Routing.Router.Settings.Matrix = &config.MatrixConfig{}
|
||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (matrix): %v", err)
|
||||
}
|
||||
if _, ok := s.local.(*router.Matrix); !ok {
|
||||
t.Fatalf("localRouter=%T want *router.Matrix", s.local)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type contextkey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type ReqContextData struct {
|
||||
ApiKey string
|
||||
Model string
|
||||
ModelID string
|
||||
Streaming bool
|
||||
SendLoadingState bool
|
||||
}
|
||||
|
||||
var (
|
||||
ReqContextKey = &contextkey{"context"}
|
||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||
)
|
||||
|
||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
switch {
|
||||
case errors.Is(err, ErrNoModelInContext):
|
||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||
case errors.Is(err, ErrNoPeerModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||
case errors.Is(err, ErrNoLocalModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
case errors.Is(err, ErrNoRouterFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||
default:
|
||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/plain") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(acceptHeader, "text/html") {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message})
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`))
|
||||
return
|
||||
}
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context then
|
||||
// from the model body. If it extracts the model from the body it will
|
||||
// store the model in the context for downstream handlers. An error
|
||||
// will be returned when model can not be fetch from either location.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
realName = data.Model
|
||||
}
|
||||
data.ModelID = realName
|
||||
if mc, ok := cfg.Models[realName]; ok {
|
||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ReqContextKey, data)
|
||||
}
|
||||
|
||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// extractContext pulls fields from an HTTP request into a ReqContextData,
|
||||
// returning whatever is available. For GET requests it reads query parameters.
|
||||
// For POST requests it inspects Content-Type and parses JSON,
|
||||
// multipart/form-data, or application/x-www-form-urlencoded bodies. The
|
||||
// request body is always restored before returning. An error is returned only
|
||||
// for I/O or parse failures, not for missing fields.
|
||||
func extractContext(r *http.Request) (ReqContextData, error) {
|
||||
|
||||
apiKey := ExtractAPIKey(r)
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
q := r.URL.Query()
|
||||
return ReqContextData{
|
||||
Model: q.Get("model"),
|
||||
Streaming: q.Get("stream") == "true",
|
||||
ApiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}()
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return ReqContextData{
|
||||
Model: gjson.GetBytes(bodyBytes, "model").String(),
|
||||
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
ApiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||
// buffered bytes. The deferred restore above will reset r.Body again
|
||||
// after parsing.
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return ReqContextData{
|
||||
Model: r.FormValue("model"),
|
||||
Streaming: r.FormValue("stream") == "true",
|
||||
ApiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func ExtractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
scheme, credentials, ok := strings.Cut(auth, " ")
|
||||
if ok {
|
||||
switch strings.ToLower(scheme) {
|
||||
case "bearer":
|
||||
bearerKey = credentials
|
||||
case "basic":
|
||||
if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
package router
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -20,13 +22,13 @@ func TestExtractContext_GET(t *testing.T) {
|
||||
}{
|
||||
{"model present", "model=llama3", "llama3", false},
|
||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||
{"model missing", "", "", true},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := ExtractContext(r)
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
@@ -46,16 +48,16 @@ func TestExtractContext_JSON(t *testing.T) {
|
||||
}{
|
||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||
{"model empty string", `{"model":""}`, "", true},
|
||||
{"model key missing", `{"stream":true}`, "", true},
|
||||
{"invalid json", `not-json`, "", true},
|
||||
{"model empty string", `{"model":""}`, "", false},
|
||||
{"model key missing", `{"stream":true}`, "", false},
|
||||
{"invalid json", `not-json`, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := ExtractContext(r)
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
@@ -74,7 +76,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", true},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -85,7 +87,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||
}
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := ExtractContext(r)
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
@@ -104,7 +106,7 @@ func TestExtractContext_MultipartForm(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", true},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -119,7 +121,7 @@ func TestExtractContext_MultipartForm(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
got, err := ExtractContext(r)
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
@@ -135,7 +137,7 @@ func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
@@ -162,7 +164,7 @@ func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
@@ -180,7 +182,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
@@ -195,7 +197,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||
|
||||
func TestSetContext(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
if !ok {
|
||||
t.Fatalf("ContextKey not set or wrong type")
|
||||
}
|
||||
@@ -209,7 +211,7 @@ func TestSetContext(t *testing.T) {
|
||||
|
||||
func TestSetContext_WithAlias(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||
data, _ := ctx.Value(ContextKey).(ReqContextData)
|
||||
data, _ := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
if data.Model != "llama" {
|
||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||
}
|
||||
@@ -221,7 +223,7 @@ func TestSetContext_WithAlias(t *testing.T) {
|
||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||
parent := context.Background()
|
||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
if v := parent.Value(ContextKey); v != nil {
|
||||
if v := parent.Value(ReqContextKey); v != nil {
|
||||
t.Errorf("parent context was mutated: %v", v)
|
||||
}
|
||||
}
|
||||
@@ -273,3 +275,152 @@ func TestReadContext(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", "model=llama3&stream=true", true},
|
||||
{"streaming false", "model=llama3&stream=false", false},
|
||||
{"no stream param", "model=llama3", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||
{"no stream param", `{"model":"llama3"}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if !got.Streaming {
|
||||
t.Error("Streaming should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_ApiKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
ct string
|
||||
body string
|
||||
auth string
|
||||
xapi string
|
||||
wantKey string
|
||||
}{
|
||||
{"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"},
|
||||
{"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"},
|
||||
{"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"},
|
||||
{"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"},
|
||||
{"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"},
|
||||
{"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"},
|
||||
{"no key", http.MethodGet, "", "", "", "", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var body io.Reader
|
||||
if c.body != "" {
|
||||
body = strings.NewReader(c.body)
|
||||
}
|
||||
r, _ := http.NewRequest(c.method, "/", body)
|
||||
if c.ct != "" {
|
||||
r.Header.Set("Content-Type", c.ct)
|
||||
}
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("extractContext: %v", err)
|
||||
}
|
||||
if got.ApiKey != c.wantKey {
|
||||
t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
{"lowercase bearer", "bearer tok123", "", "tok123"},
|
||||
{"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"},
|
||||
{"mixed case BEARER", "BEARER tok456", "", "tok456"},
|
||||
{"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := ExtractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package shared
|
||||
|
||||
import "net"
|
||||
|
||||
// IsLoopbackAddr reports whether listenAddr binds exclusively to loopback.
|
||||
// Addresses with an empty or wildcard host (e.g. ":8080", "0.0.0.0:8080",
|
||||
// "[::]:8080") bind on all interfaces and return false.
|
||||
func IsLoopbackAddr(listenAddr string) bool {
|
||||
host, _, err := net.SplitHostPort(listenAddr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
// hostname case (e.g. "localhost")
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if !net.ParseIP(a).IsLoopback() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(addrs) > 0
|
||||
}
|
||||
+36
-2
@@ -6,6 +6,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/server"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/mostlygeek/llama-swap/internal/watcher"
|
||||
@@ -122,6 +124,13 @@ func main() {
|
||||
applyLogSettings(cfg)
|
||||
proxyLog.Debugf("PID: %d", os.Getpid())
|
||||
|
||||
// On Windows, bind the process tree to a Job Object so every upstream
|
||||
// process is reaped when llama-swap exits — even on a forced kill. No-op
|
||||
// elsewhere. Non-fatal: a failure just falls back to per-process teardown.
|
||||
if err := process.SetupTreeCleanup(); err != nil {
|
||||
proxyLog.Warnf("failed to set up process tree cleanup: %v", err)
|
||||
}
|
||||
|
||||
// perfMon outlives config reloads; its config is updated in place.
|
||||
var perfMon *perf.Monitor
|
||||
if !cfg.Performance.Disabled {
|
||||
@@ -254,6 +263,11 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
if !shared.IsLoopbackAddr(listenAddr) {
|
||||
_, port, _ := net.SplitHostPort(listenAddr)
|
||||
proxyLog.Infof("llama-swap is reachable by all hosts on the network, use -listen localhost:%s to restrict to loopback only", port)
|
||||
}
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
@@ -267,6 +281,16 @@ func main() {
|
||||
proxyLog.Infof("received signal %v, shutting down", sig)
|
||||
watcherCancel()
|
||||
|
||||
// Backstop against a stalled shutdown: force the process to
|
||||
// exit once the whole graceful sequence has had its full budget.
|
||||
// On Windows the Job Object reaps upstream processes on exit, so
|
||||
// a forced exit still cleans up rather than orphaning children.
|
||||
go func() {
|
||||
time.Sleep(shutdownTimeout + 5*time.Second)
|
||||
proxyLog.Warnf("graceful shutdown exceeded %v, forcing exit", shutdownTimeout)
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
activeMu.RLock()
|
||||
srv := activeSrv
|
||||
activeMu.RUnlock()
|
||||
@@ -275,13 +299,23 @@ func main() {
|
||||
// drain without blocking on them for the full timeout.
|
||||
srv.CloseStreams()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
// Both phases share a single deadline so total shutdown is
|
||||
// bounded by shutdownTimeout rather than 2x it.
|
||||
deadline := time.Now().Add(shutdownTimeout)
|
||||
shutdownCtx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
proxyLog.Warnf("http server shutdown error: %v", err)
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(shutdownTimeout); err != nil {
|
||||
// Clamp the remaining budget to a small positive value: a
|
||||
// non-positive timeout makes the router fall back to its own
|
||||
// healthCheckTimeout, which would defeat the shared deadline.
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
remaining = time.Millisecond
|
||||
}
|
||||
if err := srv.Shutdown(remaining); err != nil {
|
||||
proxyLog.Warnf("router shutdown error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
ui_dist/*
|
||||
@@ -1,27 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -1,60 +0,0 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const ActivityLogEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = logmon.NewWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
func TestMain(m *testing.M) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(logmon.LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(logmon.LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
// Helper function to get the binary path
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
|
||||
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
|
||||
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||
t.Helper()
|
||||
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||
require.NoError(t, err)
|
||||
return cfg
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
|
||||
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
|
||||
// model IDs to their respond strings; if a model ID is not in the map, the model
|
||||
// ID itself is used.
|
||||
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||
for _, pg := range pm.processGroups {
|
||||
for modelID, process := range pg.processes {
|
||||
respond := modelID
|
||||
if r, ok := modelResponses[modelID]; ok {
|
||||
respond = r
|
||||
}
|
||||
process.testHandler = newTestHandler(respond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||
// It supports the endpoints that routing tests depend on, without launching
|
||||
// any subprocess or binding any port.
|
||||
func respondJSON(w http.ResponseWriter, respond string, bodyBytes []byte) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"h_content_length": strconv.Itoa(len(bodyBytes)),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newTestHandler(respond string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
isStreaming := r.URL.Query().Get("stream") == "true"
|
||||
|
||||
if wait := r.URL.Query().Get("wait"); wait != "" {
|
||||
if d, err := time.ParseDuration(wait); err == nil {
|
||||
time.Sleep(d)
|
||||
}
|
||||
}
|
||||
|
||||
if isStreaming {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
flusher := w.(http.Flusher)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
finalData, _ := json.Marshal(map[string]any{
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
|
||||
flusher.Flush()
|
||||
|
||||
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
} else {
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != respond {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
|
||||
for _, path := range []string{
|
||||
"/chat/completions", "/completions",
|
||||
"/responses", "/messages", "/messages/count_tokens",
|
||||
"/embeddings", "/rerank", "/reranking",
|
||||
} {
|
||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
}
|
||||
|
||||
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
model := r.FormValue("model")
|
||||
if model == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
fileBytes, _ := io.ReadAll(file)
|
||||
file.Close()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
|
||||
"model": model,
|
||||
"h_content_type": r.Header.Get("Content-Type"),
|
||||
"h_content_length": r.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
|
||||
model := r.URL.Query().Get("model")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"voices": []string{"voice1"}, "model": model,
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprint(w, respond)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
-330
@@ -1,330 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||
// It is safe for concurrent reads after construction.
|
||||
type MatrixSolver struct {
|
||||
expandedSets []config.ExpandedSet // all valid model combinations
|
||||
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||
}
|
||||
|
||||
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
|
||||
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
|
||||
modelToSets := make(map[string][]int)
|
||||
for i, es := range expandedSets {
|
||||
for _, model := range es.Models {
|
||||
modelToSets[model] = append(modelToSets[model], i)
|
||||
}
|
||||
}
|
||||
|
||||
return &MatrixSolver{
|
||||
expandedSets: expandedSets,
|
||||
evictCosts: evictCosts,
|
||||
modelToSets: modelToSets,
|
||||
}
|
||||
}
|
||||
|
||||
// SolveResult describes what the solver decided.
|
||||
type SolveResult struct {
|
||||
Evict []string // running models that must be stopped
|
||||
TargetSet []string // the chosen set of models (for informational purposes)
|
||||
SetName string // name of the chosen set
|
||||
DSL string // original DSL expression for the chosen set
|
||||
TotalCost int // total eviction cost
|
||||
}
|
||||
|
||||
// Solve determines which models to evict when a model is requested.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. If requestedModel is already running, no eviction needed.
|
||||
// 2. Find all sets containing requestedModel.
|
||||
// 3. If no sets found, the model runs alone; evict all running models.
|
||||
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||
// models NOT in that set.
|
||||
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||
// 6. Return models to evict and the chosen set.
|
||||
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
|
||||
// If already running, nothing to do (but fill in set info for logging)
|
||||
if slices.Contains(runningModels, requestedModel) {
|
||||
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||
return SolveResult{
|
||||
TargetSet: runningModels,
|
||||
SetName: setName,
|
||||
DSL: dsl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
candidateIndices := s.modelToSets[requestedModel]
|
||||
|
||||
// Model not in any set: runs alone, evict everything
|
||||
if len(candidateIndices) == 0 {
|
||||
evict := make([]string, len(runningModels))
|
||||
copy(evict, runningModels)
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: []string{requestedModel},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Find the cheapest candidate set
|
||||
bestCost := -1
|
||||
bestIdx := -1
|
||||
|
||||
for _, idx := range candidateIndices {
|
||||
setModels := s.expandedSets[idx].Models
|
||||
cost := 0
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(setModels, running) {
|
||||
cost += s.evictCost(running)
|
||||
}
|
||||
}
|
||||
|
||||
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||
bestCost = cost
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// Determine which running models to evict
|
||||
chosen := s.expandedSets[bestIdx]
|
||||
var evict []string
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(chosen.Models, running) {
|
||||
evict = append(evict, running)
|
||||
}
|
||||
}
|
||||
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: chosen.Models,
|
||||
SetName: chosen.SetName,
|
||||
DSL: chosen.DSL,
|
||||
TotalCost: bestCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// findMatchingSet finds the expanded set that contains all running models.
|
||||
// Returns the set name and DSL, or empty strings if no match.
|
||||
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||
for _, idx := range s.modelToSets[requestedModel] {
|
||||
set := s.expandedSets[idx]
|
||||
allInSet := true
|
||||
for _, m := range runningModels {
|
||||
if !slices.Contains(set.Models, m) {
|
||||
allInSet = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSet {
|
||||
return set.SetName, set.DSL
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (s *MatrixSolver) evictCost(model string) int {
|
||||
if cost, ok := s.evictCosts[model]; ok {
|
||||
return cost
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Matrix manages processes using solver-based swap logic.
|
||||
type Matrix struct {
|
||||
sync.Mutex
|
||||
solver *MatrixSolver
|
||||
processes map[string]*Process // all processes keyed by real model name
|
||||
config config.Config
|
||||
proxyLogger *logmon.Monitor
|
||||
upstreamLogger *logmon.Monitor
|
||||
|
||||
// inflight tracks ProxyRequest calls that have released m.Lock but may
|
||||
// not yet have incremented Process.inFlightRequests. A concurrent
|
||||
// request that needs to evict models waits for inflight to drain under
|
||||
// m.Lock before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
|
||||
// races with Stop()'s Wait() and can be killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook invoked in the no-eviction path
|
||||
// after m.Lock is released but before the request is dispatched to
|
||||
// Process.ProxyRequest. Tests use it to park a request at the exact
|
||||
// race window to deterministically reproduce the race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||
// model defined in the config (any model can run alone even if not in a set).
|
||||
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *logmon.Monitor) *Matrix {
|
||||
processes := make(map[string]*Process)
|
||||
for modelID, modelConfig := range cfg.Models {
|
||||
processLogger := logmon.NewWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
|
||||
processes[modelID] = process
|
||||
}
|
||||
|
||||
evictCosts := cfg.Matrix.ResolvedEvictCosts()
|
||||
|
||||
return &Matrix{
|
||||
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
|
||||
processes: processes,
|
||||
config: cfg,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyRequest handles the swap logic and proxies the request to the model.
|
||||
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("model %s not found in matrix", modelID)
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
running := m.runningModels()
|
||||
result, err := m.solver.Solve(modelID, running)
|
||||
if err != nil {
|
||||
m.Unlock()
|
||||
return fmt.Errorf("matrix solver error: %w", err)
|
||||
}
|
||||
|
||||
// Log solver decision
|
||||
if len(result.Evict) > 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||
} else if len(running) == 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
|
||||
} else {
|
||||
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
|
||||
}
|
||||
|
||||
// Evict models that need to be stopped
|
||||
if len(result.Evict) > 0 {
|
||||
// Wait for any in-flight ProxyRequest calls to register on their
|
||||
// Process before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
m.inflight.Wait()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, evictModel := range result.Evict {
|
||||
if p, exists := m.processes[evictModel]; exists {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Stop()
|
||||
}(p)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Register this request in inflight before releasing m.Lock so a
|
||||
// concurrent eviction will wait for it to complete.
|
||||
m.inflight.Add(1)
|
||||
defer m.inflight.Done()
|
||||
isFastPath := len(result.Evict) == 0
|
||||
m.Unlock()
|
||||
|
||||
if isFastPath && m.testDelayFastPath != nil {
|
||||
m.testDelayFastPath()
|
||||
}
|
||||
|
||||
// Proxy the request (Process handles on-demand start)
|
||||
process.ProxyRequest(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProcesses stops all running processes.
|
||||
func (m *Matrix) StopProcesses(strategy StopStrategy) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
p.StopImmediately()
|
||||
default:
|
||||
p.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// StopProcess stops a single process by model ID.
|
||||
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down all processes.
|
||||
func (m *Matrix) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// RunningModels returns model names currently in an active (non-stopped) state.
|
||||
func (m *Matrix) RunningModels() []string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.runningModels()
|
||||
}
|
||||
|
||||
// runningModels returns running model names (caller must hold lock).
|
||||
func (m *Matrix) runningModels() []string {
|
||||
var running []string
|
||||
for id, process := range m.processes {
|
||||
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
|
||||
running = append(running, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
|
||||
// GetProcess returns the Process for a model.
|
||||
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
|
||||
p, ok := m.processes[modelID]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// HasModel returns true if the model is managed by this matrix.
|
||||
func (m *Matrix) HasModel(modelID string) bool {
|
||||
_, ok := m.processes[modelID]
|
||||
return ok
|
||||
}
|
||||
@@ -1,349 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper to build expanded sets for solver tests
|
||||
func makeExpandedSets(sets ...struct {
|
||||
name string
|
||||
models []string
|
||||
}) []config.ExpandedSet {
|
||||
var result []config.ExpandedSet
|
||||
for _, s := range sets {
|
||||
result = append(result, config.ExpandedSet{
|
||||
SetName: s.name,
|
||||
Models: s.models,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func es(name string, models ...string) struct {
|
||||
name string
|
||||
models []string
|
||||
} {
|
||||
return struct {
|
||||
name string
|
||||
models []string
|
||||
}{name, models}
|
||||
}
|
||||
|
||||
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"a"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a"}, result.TargetSet)
|
||||
assert.Equal(t, "s1", result.SetName)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Model "c" not in any set
|
||||
result, err := solver.Solve("c", []string{"a", "b"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("c", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
|
||||
// Set: [a, b]. Request a when b and c are running.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"b", "c"})
|
||||
require.NoError(t, err)
|
||||
// c is not in the set, so it gets evicted. b is in the set, so it stays.
|
||||
assert.Equal(t, []string{"c"}, result.Evict)
|
||||
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
|
||||
// Two sets containing model "a":
|
||||
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
|
||||
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "v"),
|
||||
es("s2", "a", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30},
|
||||
)
|
||||
|
||||
// v is running. Switching to a:
|
||||
// s1 cost: v is in s1, so 0
|
||||
// s2 cost: v is NOT in s2, so 50
|
||||
// => pick s1
|
||||
result, err := solver.Solve("a", []string{"v"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
|
||||
|
||||
// L is running. Switching to a:
|
||||
// s1 cost: L is NOT in s1, so 30
|
||||
// s2 cost: L is in s2, so 0
|
||||
// => pick s2
|
||||
result, err = solver.Solve("a", []string{"L"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
|
||||
// Two sets with identical cost. Definition order should win.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "x"),
|
||||
es("s2", "a", "y"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Nothing running, both sets cost 0. s1 is first.
|
||||
result, err := solver.Solve("a", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
|
||||
// Model "v" costs 50 to evict, "m" costs 1 (default).
|
||||
// Sets: [g,v], [g,m]
|
||||
// Running: v, m. Request g.
|
||||
// s1=[g,v]: evict m (cost 1), keep v
|
||||
// s2=[g,m]: evict v (cost 50), keep m
|
||||
// => pick s1
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "g", "m"),
|
||||
),
|
||||
map[string]int{"v": 50},
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{"v", "m"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"m"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "q", "v"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
|
||||
// cannot stop a process while an in-flight ProxyRequest for that process is
|
||||
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
|
||||
// matrix-level inflight tracking, the eviction's Stop() races with the
|
||||
// pending request and kills it mid-start.
|
||||
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"model1"}},
|
||||
{SetName: "s2", Models: []string{"model2"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
m := NewMatrix(cfg, testLogger, testLogger)
|
||||
defer m.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
m.processes["model1"].testHandler = newTestHandler("model1")
|
||||
m.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 so it reaches StateReady and
|
||||
// subsequent requests take the no-eviction path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
|
||||
|
||||
// Install fast-path hook that signals arrival and waits for release.
|
||||
// This parks R2 at the race window — after m.Lock is released but
|
||||
// before Process.inFlightRequests.Add(1).
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
m.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: no-eviction request for model1. Will pause at the hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: request for model2 which requires evicting model1. Must wait for
|
||||
// R2 to finish before touching model1.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired m.Lock and entered the eviction path. In
|
||||
// the fixed code, R3 then blocks on m.inflight.Wait() while still
|
||||
// holding the lock, so TryLock keeps failing.
|
||||
for m.TryLock() {
|
||||
m.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code R3 is blocked and nothing changes; in the
|
||||
// buggy code R3 will Stop() model1 and start model2 within microseconds.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if m.processes["model1"].CurrentState() != StateReady ||
|
||||
m.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while an in-flight request is pending")
|
||||
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is evicted")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||
// Simulates the example config:
|
||||
// standard: [g,v], [q,v], [m,v]
|
||||
// with_rerank: [g,v,e], [q,v,e]
|
||||
// creative: [g,sd], [q,sd]
|
||||
// full: [L]
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("standard", "g", "v"),
|
||||
es("standard", "q", "v"),
|
||||
es("standard", "m", "v"),
|
||||
es("with_rerank", "e", "g", "v"),
|
||||
es("with_rerank", "e", "q", "v"),
|
||||
es("creative", "g", "sd"),
|
||||
es("creative", "q", "sd"),
|
||||
es("full", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30, "whisper": 10},
|
||||
)
|
||||
|
||||
// Running: g, v. Request q.
|
||||
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
|
||||
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
|
||||
// => tie, pick first by definition order = standard[q,v]
|
||||
result, err := solver.Solve("q", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"g"}, result.Evict)
|
||||
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request L.
|
||||
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// Only one set contains L, so pick it.
|
||||
result, err = solver.Solve("L", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
|
||||
assert.Equal(t, []string{"L"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request sd.
|
||||
// creative[g,sd]: evict v (cost 50). Total: 50.
|
||||
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// => pick creative[g,sd]
|
||||
result, err = solver.Solve("sd", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"v"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
|
||||
|
||||
// Running: q, v, e. Request g.
|
||||
// standard[g,v]: evict q (1) + e (1). Total: 2.
|
||||
// with_rerank[g,v,e]: evict q (1). Total: 1.
|
||||
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
|
||||
// => pick with_rerank[g,v,e]
|
||||
result, err = solver.Solve("g", []string{"e", "q", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"q"}, result.Evict)
|
||||
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
|
||||
}
|
||||
@@ -1,689 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/mostlygeek/llama-swap/internal/cache"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||
var zstdEncOptions = []zstd.EOption{
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
}
|
||||
|
||||
// zstdDecOptions are the shared zstd decoder options.
|
||||
var zstdDecOptions = []zstd.DOption{}
|
||||
|
||||
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||
var zstdEncPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||
return enc
|
||||
},
|
||||
}
|
||||
|
||||
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||
var zstdDecPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||
// Returns compressed bytes and the original CBOR byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
cborBytes, err := cbor.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(zenc)
|
||||
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
|
||||
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
cborBytes, err := dec.DecodeAll(data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||
}
|
||||
var capture ReqRespCapture
|
||||
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||
}
|
||||
return &capture, nil
|
||||
}
|
||||
|
||||
// TokenMetrics holds token usage and performance metrics
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
}
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// ActivityLogEvent represents a token metrics event
|
||||
type ActivityLogEvent struct {
|
||||
Metrics ActivityLogEntry
|
||||
}
|
||||
|
||||
func (e ActivityLogEvent) Type() uint32 {
|
||||
return ActivityLogEventID // defined in events.go
|
||||
}
|
||||
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics ring.Buffer[ActivityLogEntry]
|
||||
nextID int
|
||||
logger *logmon.Monitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
mm := &metricsMonitor{
|
||||
logger: logger,
|
||||
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
}
|
||||
if captureBufferMB > 0 {
|
||||
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||
}
|
||||
return mm
|
||||
}
|
||||
|
||||
// queueMetrics adds a new metric to the collection without emitting an event.
|
||||
// Returns the assigned metric ID. Call emitMetric after capture setup.
|
||||
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics.Push(metric)
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// addCapture compresses and stores a capture in the cache.
|
||||
// Returns true if the capture was stored, false otherwise.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||
return false
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
return true
|
||||
}
|
||||
|
||||
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||
if mp.captureCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
data, err := mp.captureCache.Get(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return data, true
|
||||
}
|
||||
|
||||
// getCaptureByID decompresses and unmarshals a capture by ID.
|
||||
// Returns nil if the capture is not found or decompression fails.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
if mp.captureCache == nil {
|
||||
return nil
|
||||
}
|
||||
data, exists := mp.getCompressedBytes(id)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
capture, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return capture
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics with HasCapture resolved from cache.
|
||||
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return []ActivityLogEntry{}
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns metrics as JSON with HasCapture resolved from cache.
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return json.Marshal([]ActivityLogEntry{})
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// Capture field flags for controlling what is saved in ReqRespCapture.
|
||||
type captureFields uint
|
||||
|
||||
const (
|
||||
captureNone captureFields = 1 << iota
|
||||
captureReqHeaders
|
||||
captureReqBody
|
||||
captureRespHeaders
|
||||
captureRespBody
|
||||
)
|
||||
|
||||
const (
|
||||
captureReqAll = captureReqHeaders | captureReqBody
|
||||
captureRespAll = captureRespHeaders | captureRespBody
|
||||
captureAll = captureReqAll | captureRespAll
|
||||
)
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics.
|
||||
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
captureFields captureFields,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
}
|
||||
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
// Initialize default metrics - recorded for every request
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
ReqPath: request.URL.Path,
|
||||
RespContentType: recorder.Header().Get("Content-Type"),
|
||||
RespStatusCode: recorder.Status(),
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsed.Tokens
|
||||
tm.DurationMs = parsed.DurationMs
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsedMetrics.Tokens
|
||||
tm.DurationMs = parsedMetrics.DurationMs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
var respHeaders map[string]string
|
||||
var respBody []byte
|
||||
if (captureFields & captureRespHeaders) != 0 {
|
||||
respHeaders = make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
}
|
||||
if (captureFields & captureRespBody) != 0 {
|
||||
respBody = body
|
||||
}
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.queueMetrics(tm)
|
||||
tm.ID = metricID
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
if mp.addCapture(*capture) {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
mp.emitMetric(tm)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
// v1/chat/completions puts it at top-level "usage"; v1/responses nests under
|
||||
// "response.usage"; v1/messages emits it at "message.usage" on message_start
|
||||
// and at "usage" on message_delta.
|
||||
var usagePaths = []string{"usage", "response.usage", "message.usage"}
|
||||
|
||||
// extractUsageTokens reads input/output/cached token counts from a usage
|
||||
// gjson.Result, handling the field-name differences across endpoints.
|
||||
// cached returns -1 when the field is absent. ok is true when at least one
|
||||
// field was present.
|
||||
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
|
||||
cached = -1
|
||||
if !usage.Exists() {
|
||||
return
|
||||
}
|
||||
|
||||
if v := usage.Get("prompt_tokens"); v.Exists() {
|
||||
// v1/chat/completions
|
||||
input = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens"); v.Exists() {
|
||||
// v1/messages, v1/responses
|
||||
input = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("completion_tokens"); v.Exists() {
|
||||
// v1/chat/completions
|
||||
output = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("output_tokens"); v.Exists() {
|
||||
// v1/messages, v1/responses
|
||||
output = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
|
||||
// v1/messages (Anthropic)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
|
||||
// v1/responses (OpenAI Responses API)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
|
||||
// v1/chat/completions (OpenAI cache hits)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||
// Walk SSE "data:" lines forward, merging usage info from every event.
|
||||
// Different endpoints split usage across events:
|
||||
// - v1/chat/completions: usage on the final chunk before [DONE]
|
||||
// - v1/responses: usage on response.completed (response.usage)
|
||||
// - v1/messages: input + cache on message_start (message.usage),
|
||||
// output_tokens on message_delta (usage)
|
||||
// We take the latest informative value per field so all three are covered.
|
||||
|
||||
var (
|
||||
inputTokens, outputTokens int64
|
||||
cachedTokens int64 = -1
|
||||
hasAny bool
|
||||
timings gjson.Result
|
||||
)
|
||||
|
||||
prefix := []byte("data:")
|
||||
for offset := 0; offset < len(body); {
|
||||
nl := bytes.IndexByte(body[offset:], '\n')
|
||||
var line []byte
|
||||
if nl == -1 {
|
||||
line = body[offset:]
|
||||
offset = len(body)
|
||||
} else {
|
||||
line = body[offset : offset+nl]
|
||||
offset += nl + 1
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
continue
|
||||
}
|
||||
parsed := gjson.ParseBytes(data)
|
||||
|
||||
for _, path := range usagePaths {
|
||||
u := parsed.Get(path)
|
||||
if !u.Exists() {
|
||||
continue
|
||||
}
|
||||
i, o, c, ok := extractUsageTokens(u)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hasAny = true
|
||||
// Take the latest non-zero value so message_start's input_tokens
|
||||
// is preserved when message_delta's usage omits it, and vice versa
|
||||
// for output_tokens.
|
||||
if i > 0 {
|
||||
inputTokens = i
|
||||
}
|
||||
if o > 0 {
|
||||
outputTokens = o
|
||||
}
|
||||
if c >= 0 {
|
||||
cachedTokens = c
|
||||
}
|
||||
}
|
||||
if t := parsed.Get("timings"); t.Exists() {
|
||||
timings = t
|
||||
hasAny = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAny {
|
||||
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||
input, output, cached, _ := extractUsageTokens(usage)
|
||||
return buildMetrics(modelID, start, input, output, cached, timings), nil
|
||||
}
|
||||
|
||||
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
|
||||
// optional llama-server timings (which override input/output and provide rates).
|
||||
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
outputTokens = timings.Get("predicted_n").Int()
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
},
|
||||
DurationMs: durationMs,
|
||||
}
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||
bodyBuffer := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: bodyBuffer,
|
||||
tee: io.MultiWriter(w, bodyBuffer),
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,144 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *logmon.Monitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
// Create a transport with per-peer timeout configuration
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -1,311 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, peerProxy)
|
||||
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||
|
||||
// Verify the timeout values are actually applied to the transport
|
||||
member, found := peerProxy.proxyMap["model1"]
|
||||
require.True(t, found, "model1 should exist in proxyMap")
|
||||
assert.NotNil(t, member.reverseProxy)
|
||||
assert.NotNil(t, member.reverseProxy.Transport)
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
require.True(t, ok, "Transport should be *http.Transport")
|
||||
|
||||
// Verify all timeout values are correctly applied
|
||||
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
// ForceAttemptHTTP2 should be enabled
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -1,956 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type StopStrategy int
|
||||
|
||||
const (
|
||||
StopImmediately StopStrategy = iota
|
||||
StopWaitForInflightRequest
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config config.ModelConfig
|
||||
cmd *exec.Cmd
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
|
||||
// PR #155 called to cancel the upstream process
|
||||
cmdMutex sync.RWMutex
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
cmdWaitChan chan struct{}
|
||||
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandledMutex sync.RWMutex
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequestsCount atomic.Int32
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing concurrency limits
|
||||
concurrencyLimitSemaphore chan struct{}
|
||||
|
||||
// used for testing to override the default value
|
||||
gracefulStopTimeout time.Duration
|
||||
|
||||
// used for testing to bypass subprocess and reverse proxy
|
||||
testHandler http.Handler
|
||||
|
||||
// track the number of failed starts
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *logmon.Monitor, proxyLogger *logmon.Monitor) *Process {
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
// Setup the reverse proxy.
|
||||
proxyURL, err := url.Parse(config.Proxy)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||
}
|
||||
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
|
||||
// Create custom transport with configured timeouts
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.Transport = transport
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
reverseProxy: reverseProxy,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
|
||||
// concurrency limit
|
||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||
|
||||
// To be removed when migration over exec.CommandContext is complete
|
||||
// stop timeout
|
||||
gracefulStopTimeout: 10 * time.Second,
|
||||
cmdWaitChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// LogMonitor returns the log monitor associated with the process.
|
||||
func (p *Process) LogMonitor() *logmon.Monitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||
p.lastRequestHandledMutex.Lock()
|
||||
defer p.lastRequestHandledMutex.Unlock()
|
||||
p.lastRequestHandled = t
|
||||
}
|
||||
|
||||
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) getLastRequestHandled() time.Time {
|
||||
p.lastRequestHandledMutex.RLock()
|
||||
defer p.lastRequestHandledMutex.RUnlock()
|
||||
return p.lastRequestHandled
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
)
|
||||
|
||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||
// and an error if the swap failed.
|
||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != expectedState {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
|
||||
// Atomically increment waitStarting when entering StateStarting
|
||||
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||
if newState == StateStarting {
|
||||
p.waitStarting.Add(1)
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
// Helper function to encapsulate transition rules
|
||||
func isValidTransition(from, to ProcessState) bool {
|
||||
switch from {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateStopping || to == StateStopped
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// forceState forces the process state to the new state with mutex protection.
|
||||
// This should only be used in exceptional cases where the normal state transition
|
||||
// validation via swapState() cannot be used.
|
||||
func (p *Process) forceState(newState ProcessState) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.state = newState
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
// test-only fast path: skip subprocess, health check, and TTL goroutine
|
||||
if p.testHandler != nil {
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
curState = p.CurrentState()
|
||||
if curState == StateReady {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", curState)
|
||||
}
|
||||
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||
}
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
// Mimic the real stop path: cancelUpstream transitions
|
||||
// StateStopping -> StateStopped and closes cmdWaitChan,
|
||||
// matching what waitForCmd does for real subprocesses.
|
||||
ch := make(chan struct{})
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = func() {
|
||||
if curState := p.CurrentState(); curState == StateStopping {
|
||||
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
} else {
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
close(ch)
|
||||
}
|
||||
p.cmdWaitChan = ch
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
// already starting, just wait for it to complete and expect
|
||||
// it to be be in the Ready start after. If not, return an error
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
if state := p.CurrentState(); state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
setProcAttributes(p.cmd)
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.forceState(StateStopped) // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
go p.waitForCmd()
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
// Ready Check loop
|
||||
for {
|
||||
currentState := p.CurrentState()
|
||||
if currentState != StateStarting {
|
||||
if currentState == StateStopped {
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||
}
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
}
|
||||
|
||||
if time.Since(checkStartTime) > maxDuration {
|
||||
p.stopCommand()
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||
break
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||
}
|
||||
}
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
|
||||
if p.config.UnloadAfter > 0 {
|
||||
// start a goroutine to check every second if
|
||||
// the process should be stopped
|
||||
go func() {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.CurrentState() != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// skip the TTL check if there are inflight requests
|
||||
if p.inFlightRequestsCount.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop will wait for inflight requests to complete before stopping the process.
|
||||
func (p *Process) Stop() {
|
||||
|
||||
// guard to prevent multiple goroutines from stopping
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||
return
|
||||
}
|
||||
|
||||
// wait for any inflight requests before proceeding
|
||||
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||
p.inFlightRequests.Wait()
|
||||
p.StopImmediately()
|
||||
}
|
||||
|
||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||
func (p *Process) StopImmediately() {
|
||||
|
||||
// guard to prevent multiple goroutines from stopping the process
|
||||
enterState := p.CurrentState()
|
||||
if !isValidTransition(enterState, StateStopping) {
|
||||
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||
return
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
|
||||
if curState, err := p.swapState(enterState, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||
// the StateShutdown state, it can not be started again.
|
||||
func (p *Process) Shutdown() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.forceState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
|
||||
// free the buffer in processLogger so the memory can be recovered
|
||||
p.processLogger.Clear()
|
||||
}()
|
||||
|
||||
p.cmdMutex.RLock()
|
||||
cancelUpstream := p.cancelUpstream
|
||||
cmdWaitChan := p.cmdWaitChan
|
||||
p.cmdMutex.RUnlock()
|
||||
|
||||
if cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
cancelUpstream()
|
||||
<-cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
// wait a short time for a tcp connection to be established
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).DialContext,
|
||||
},
|
||||
|
||||
// give a long time to respond to the health check endpoint
|
||||
// after the connection is established. See issue: 276
|
||||
Timeout: 5000 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if p.reverseProxy == nil {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||
default:
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
p.inFlightRequestsCount.Add(1)
|
||||
defer func() {
|
||||
p.setLastRequestHandled(time.Now())
|
||||
p.inFlightRequestsCount.Add(-1)
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// for #366
|
||||
// - extract streaming param from request context, should have been set by proxymanager
|
||||
var srw *statusResponseWriter
|
||||
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
// start a goroutine to stream loading status messages into the response writer
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
|
||||
// PR #417 (no support for anthropic v1/messages yet)
|
||||
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
|
||||
}
|
||||
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
cancelLoadCtx()
|
||||
if srw != nil {
|
||||
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
|
||||
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
|
||||
// before closing the connection. Without this, the connection would close before
|
||||
// the goroutine can write its cleanup messages, causing incomplete SSE output.
|
||||
srw.waitForCompletion(100 * time.Millisecond)
|
||||
} else {
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
// should trigger srw to stop sending loading events ...
|
||||
cancelLoadCtx()
|
||||
|
||||
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||
// disconnects before the response is sent
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if r == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if srw != nil {
|
||||
// Wait for the goroutine to finish writing its final messages
|
||||
const completionTimeout = 1 * time.Second
|
||||
if !srw.waitForCompletion(completionTimeout) {
|
||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if p.testHandler != nil {
|
||||
p.testHandler.ServeHTTP(w, r)
|
||||
} else if srw != nil {
|
||||
p.reverseProxy.ServeHTTP(srw, r)
|
||||
} else {
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||
p.ID, r.RequestURI, startDuration, totalTime)
|
||||
}
|
||||
|
||||
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||
func (p *Process) waitForCmd() {
|
||||
exitErr := p.cmd.Wait()
|
||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||
|
||||
if exitErr != nil {
|
||||
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentState := p.CurrentState()
|
||||
switch currentState {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.forceState(StateStopped) // force it to be in this state
|
||||
}
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
close(p.cmdWaitChan)
|
||||
p.cmdMutex.Unlock()
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
func (p *Process) cmdStopUpstreamProcess() error {
|
||||
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||
|
||||
// this should never happen ...
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger returns the logger for this process.
|
||||
func (p *Process) Logger() *logmon.Monitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
"Waking up the hamsters...",
|
||||
"Teaching the model manners...",
|
||||
"Convincing the GPU to participate...",
|
||||
"Loading weights (they're heavy)...",
|
||||
"Herding electrons...",
|
||||
"Compiling excuses for the delay...",
|
||||
"Downloading more RAM...",
|
||||
"Asking the model nicely to boot up...",
|
||||
"Bribing CUDA with cookies...",
|
||||
"Still loading (blame VRAM)...",
|
||||
"The model is fashionably late...",
|
||||
"Warming up those tensors...",
|
||||
"Making the neural net do push-ups...",
|
||||
"Your patience is appreciated (really)...",
|
||||
"Almost there (probably)...",
|
||||
"Loading like it's 1999...",
|
||||
"The model forgot where it put its keys...",
|
||||
"Quantum tunneling through layers...",
|
||||
"Negotiating with the PCIe bus...",
|
||||
"Defrosting frozen parameters...",
|
||||
"Teaching attention heads to focus...",
|
||||
"Running the matrix (slowly)...",
|
||||
"Untangling transformer blocks...",
|
||||
"Calibrating the flux capacitor...",
|
||||
"Spinning up the probability wheels...",
|
||||
"Waiting for the GPU to wake from its nap...",
|
||||
"Converting caffeine to compute...",
|
||||
"Allocating virtual patience...",
|
||||
"Performing arcane CUDA rituals...",
|
||||
"The model is stuck in traffic...",
|
||||
"Inflating embeddings...",
|
||||
"Summoning computational demons...",
|
||||
"Pleading with the OOM killer...",
|
||||
"Calculating the meaning of life (still at 42)...",
|
||||
"Training the training wheels...",
|
||||
"Optimizing the optimizer...",
|
||||
"Bootstrapping the bootstrapper...",
|
||||
"Loading loading screen...",
|
||||
"Processing processing logs...",
|
||||
"Buffering buffer overflow jokes...",
|
||||
"The model hit snooze...",
|
||||
"Debugging the debugger...",
|
||||
"Compiling the compiler...",
|
||||
"Parsing the parser (meta)...",
|
||||
"Tokenizing tokens...",
|
||||
"Encoding the encoder...",
|
||||
"Hashing hash browns...",
|
||||
"Forking spoons (not forks)...",
|
||||
"The model is contemplating existence...",
|
||||
"Transcending dimensional barriers...",
|
||||
"Invoking elder tensor gods...",
|
||||
"Unfurling probability clouds...",
|
||||
"Synchronizing parallel universes...",
|
||||
"The GPU is having second thoughts...",
|
||||
"Recalibrating reality matrices...",
|
||||
"Time is an illusion, loading doubly so...",
|
||||
"Convincing bits to flip themselves...",
|
||||
"The model is reading its own documentation...",
|
||||
}
|
||||
|
||||
type statusResponseWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
process *Process
|
||||
wg sync.WaitGroup // Track goroutine completion
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
|
||||
s := &statusResponseWriter{
|
||||
writer: w,
|
||||
process: p,
|
||||
start: time.Now(),
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream") // SSE
|
||||
s.Header().Set("Cache-Control", "no-cache") // no-cache
|
||||
s.Header().Set("Connection", "keep-alive") // keep-alive
|
||||
s.WriteHeader(http.StatusOK) // send status code 200
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
|
||||
return s
|
||||
}
|
||||
|
||||
// statusUpdates sends status updates to the client while the model is loading
|
||||
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Recover from panics caused by client disconnection
|
||||
// Note: recover() only works within the same goroutine, so we need it here
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(s.start)
|
||||
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
// Create a shuffled copy of loadingRemarks
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
// Pick a random duration to send a remark
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Now()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if s.process.CurrentState() == StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's time for a snarky remark
|
||||
if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendLine(fmt.Sprintf("\n%s", remark))
|
||||
lastRemarkTime = time.Now()
|
||||
// Pick a new random duration for the next remark
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCompletion waits for the statusUpdates goroutine to finish
|
||||
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendLine(line string) {
|
||||
s.sendData(line + "\n")
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendData(data string) {
|
||||
// Create the proper SSE JSON structure
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write SSE formatted data, panic if not able to write
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
|
||||
}
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -1,609 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
debugLogger = logmon.NewWriter(os.Stdout)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// flip to help with debugging tests
|
||||
if false {
|
||||
debugLogger.SetLogLevel(logmon.LevelDebug)
|
||||
} else {
|
||||
debugLogger.SetLogLevel(logmon.LevelError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// process is automatically started
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
|
||||
// Stop the process
|
||||
process.Stop()
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Proxy the request
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// should have automatically started the process again
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := config.ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
|
||||
process.ProxyRequest(w, req1)
|
||||
|
||||
t.Log("sending slow first request (4 seconds)")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "1234")
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// ensure the TTL timeout does not race slow requests (see issue #25)
|
||||
t.Log("sending second request (1 second)")
|
||||
time.Sleep(time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req2)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// wait 5 seconds
|
||||
t.Log("sleep 5 seconds and check if unloaded")
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
"12345": "",
|
||||
"abcde": "",
|
||||
"fghij": "",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range results {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
|
||||
}(key)
|
||||
}
|
||||
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping Exit Interrupts Health Check test")
|
||||
}
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := config.ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long concurrency limit test")
|
||||
}
|
||||
|
||||
expectedMessage := "concurrency_limit_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// only allow 1 concurrent request at a time
|
||||
config.ConcurrencyLimit = 1
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||
defer process.Stop()
|
||||
|
||||
// launch a goroutine first to take up the semaphore
|
||||
go func() {
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req1)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}()
|
||||
|
||||
// let the goroutine start
|
||||
<-time.After(time.Millisecond * 25)
|
||||
|
||||
denied := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, denied)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestProcess_StopImmediately(t *testing.T) {
|
||||
expectedMessage := "test_stop_immediate"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
}()
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
conf := config.ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
process.gracefulStopTimeout = time.Second
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// StatusOK because that was already sent before the kill
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||
// then the unexpected EOF is sent after the kill
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
// Upstream may be killed mid-response.
|
||||
// Assert an incomplete or partial response.
|
||||
assert.NotEqual(t, "12345", w.Body.String())
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
}()
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
conf.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := conf
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
defer process1.Stop()
|
||||
process2.start()
|
||||
defer process2.Stop()
|
||||
|
||||
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||
// handled appropriately.
|
||||
//
|
||||
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||
// response is sent from some reason.
|
||||
//
|
||||
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
expectedMessage := "panic_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// Start the process
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// Create a custom ResponseWriter that simulates a client disconnect
|
||||
// by panicking when Write is called after headers are sent
|
||||
panicWriter := &panicOnWriteResponseWriter{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
shouldPanic: true,
|
||||
}
|
||||
|
||||
// Make a request that will trigger the panic
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||
|
||||
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||
// ProxyRequest should catch and handle this panic gracefully.
|
||||
process.ProxyRequest(panicWriter, req)
|
||||
|
||||
// If we get here, the panic was properly recovered in ProxyRequest
|
||||
// The process should still be in a ready state
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||
// to simulate a client disconnect after headers are sent
|
||||
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||
type panicOnWriteResponseWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
shouldPanic bool
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||
w.headerWritten = true
|
||||
w.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.shouldPanic && w.headerWritten {
|
||||
// Simulate the panic that httputil.ReverseProxy throws
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
func TestProcess_CustomTimeouts(t *testing.T) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Cmd: "echo test",
|
||||
Proxy: "http://localhost:8080",
|
||||
CheckEndpoint: "/health",
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 120,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
}
|
||||
|
||||
debugLogger := logmon.NewWriter(io.Discard)
|
||||
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
|
||||
|
||||
// Verify the process was created successfully
|
||||
assert.NotNil(t, process)
|
||||
assert.Equal(t, "test-model", process.ID)
|
||||
assert.NotNil(t, process.reverseProxy)
|
||||
assert.NotNil(t, process.reverseProxy.Transport)
|
||||
|
||||
// Verify it's using http.Transport (not some other type)
|
||||
transport, ok := process.reverseProxy.Transport.(*http.Transport)
|
||||
assert.True(t, ok, "Transport should be *http.Transport")
|
||||
assert.NotNil(t, transport)
|
||||
|
||||
// Verify the timeouts are correctly applied
|
||||
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config config.Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *logmon.Monitor
|
||||
upstreamLogger *logmon.Monitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
|
||||
// inflight tracks fast-path requests (requests for the already-selected
|
||||
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
|
||||
// and Done() on completion; a concurrent swap request calls inflight.Wait()
|
||||
// under pg.Lock before stopping the current process. Without this tracking,
|
||||
// a fast-path request that has released pg.Lock but has not yet called
|
||||
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
|
||||
// killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
|
||||
// the fast path after pg.Lock is released but before the request is
|
||||
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
|
||||
// request at the exact race window to deterministically reproduce the
|
||||
// fast-path vs swap race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config config.Config, proxyLogger *logmon.Monitor, upstreamLogger *logmon.Monitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
processLogger := logmon.NewWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
|
||||
// Wait for in-flight fast-path requests to drain before stopping
|
||||
// the previous process. Without this, a fast-path request that has
|
||||
// released pg.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
pg.inflight.Wait()
|
||||
|
||||
// is there something already running?
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
|
||||
// wait for the request to the new model to be fully handled
|
||||
// and prevent race conditions see issue #277
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
pg.lastUsedProcess = modelID
|
||||
|
||||
// short circuit and exit
|
||||
pg.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path: register this request in inflight before releasing
|
||||
// pg.Lock so a concurrent swap will wait for it to complete.
|
||||
pg.inflight.Add(1)
|
||||
defer pg.inflight.Done()
|
||||
pg.Unlock()
|
||||
|
||||
if pg.testDelayFastPath != nil {
|
||||
pg.testDelayFastPath()
|
||||
}
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
if strategy != StopImmediately {
|
||||
pg.inflight.Wait()
|
||||
}
|
||||
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
// use the same listening so if a model is already running, it will fail
|
||||
// this is a way to test that swap isolation is working
|
||||
// properly when there are parallel requests made at the
|
||||
// same time.
|
||||
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(tests))
|
||||
for _, modelName := range tests {
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
}(modelName)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
|
||||
// request cannot stop the current process while a fast-path request (for the
|
||||
// already-selected model) is in flight. Without ProcessGroup-level inflight
|
||||
// tracking, a fast-path request that has released pg.Lock but has not yet
|
||||
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
|
||||
// process is killed mid-request.
|
||||
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 via the swap path so that
|
||||
// lastUsedProcess == "model1" and subsequent model1 requests take the
|
||||
// fast path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
|
||||
|
||||
// Fast-path hook: signal arrival at the race window, then wait for
|
||||
// release. This parks R2 deterministically at the point where pg.Lock
|
||||
// has been released but Process.inFlightRequests has not yet been
|
||||
// incremented — the exact window the race exploits.
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
pg.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: fast-path request for model1. Will pause at the test hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: swap request for model2. Must wait for R2 to finish before touching
|
||||
// model1, otherwise model1 gets killed mid-request.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired pg.Lock and entered the swap critical
|
||||
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
|
||||
// still holding the lock, so TryLock keeps failing.
|
||||
for pg.TryLock() {
|
||||
pg.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
|
||||
// nothing changes, so we wait the full window. In the buggy code, R3
|
||||
// will Stop() model1 and start serving via model2 within microseconds —
|
||||
// we exit early once the mutation is observable.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if pg.processes["model1"].CurrentState() != StateReady ||
|
||||
pg.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while a fast-path request is in flight")
|
||||
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is swapped out")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
|
||||
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
|
||||
// process while a fast-path ProxyRequest is in the [pg.Unlock,
|
||||
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
|
||||
// StopProcesses, the external caller bypasses the inflight guard and kills the
|
||||
// process mid-request.
|
||||
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
|
||||
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopImmediately)
|
||||
|
||||
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: model1 is active so subsequent model1 requests take the fast path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||
|
||||
// Park a fast-path request at the race window.
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
pg.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
<-r2Reached
|
||||
|
||||
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
|
||||
// the group while a fast-path request is in flight.
|
||||
r3Done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
pg.StopProcesses(StopWaitForInflightRequest)
|
||||
}()
|
||||
|
||||
// Spin until StopProcesses has acquired pg.Lock.
|
||||
for pg.TryLock() {
|
||||
pg.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
|
||||
// and model1 stays Ready. In the buggy code it proceeds immediately and
|
||||
// kills model1.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if pg.processes["model1"].CurrentState() != StateReady {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-r3Done:
|
||||
goto done
|
||||
default:
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
done:
|
||||
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while a fast-path request is in flight")
|
||||
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,358 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/performance", pm.apiGetPerformance)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) getModelStatus() []Model {
|
||||
// Extract keys and sort them
|
||||
models := []Model{}
|
||||
|
||||
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||
for modelID := range pm.config.Models {
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
// Get process state
|
||||
state := "unknown"
|
||||
var process *Process
|
||||
if pm.matrix != nil {
|
||||
process, _ = pm.matrix.GetProcess(modelID)
|
||||
} else {
|
||||
processGroup := pm.findGroupByModelName(modelID)
|
||||
if processGroup != nil {
|
||||
process = processGroup.processes[modelID]
|
||||
}
|
||||
}
|
||||
if process != nil {
|
||||
switch process.CurrentState() {
|
||||
case StateReady:
|
||||
state = "ready"
|
||||
case StateStarting:
|
||||
state = "starting"
|
||||
case StateStopping:
|
||||
state = "stopping"
|
||||
case StateShutdown:
|
||||
state = "shutdown"
|
||||
case StateStopped:
|
||||
state = "stopped"
|
||||
}
|
||||
}
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
Name: pm.config.Models[modelID].Name,
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
Aliases: pm.config.Models[modelID].Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// sends a stream of different message types that happen on the server
|
||||
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering SSE
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
sendBuffer := make(chan messageEnvelope, 25)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
sendModels := func() {
|
||||
data, err := json.Marshal(pm.getModelStatus())
|
||||
if err == nil {
|
||||
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sendLogData := func(source string, data []byte) {
|
||||
data, err := json.Marshal(gin.H{
|
||||
"source": source,
|
||||
"data": string(data),
|
||||
})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendInFlight := func(total int) {
|
||||
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
defer event.On(func(e ProcessStateChangeEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
defer event.On(func(e ConfigFileChangedEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Log data
|
||||
*/
|
||||
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("proxy", data)
|
||||
})()
|
||||
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("upstream", data)
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e ActivityLogEvent) {
|
||||
sendMetrics([]ActivityLogEntry{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||
*/
|
||||
defer event.On(func(e InFlightRequestsEvent) {
|
||||
sendInFlight(e.Total)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
sendInFlight(pm.inFlightCounter.Current())
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
c.SSEvent("message", msg)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) prometheusMetricsHandler(c *gin.Context) {
|
||||
if pm.perfMonitor == nil {
|
||||
c.String(http.StatusServiceUnavailable, "# performance monitor not available\n")
|
||||
return
|
||||
}
|
||||
pm.perfMonitor.MetricsHandler().ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetPerformance(c *gin.Context) {
|
||||
if pm.perfMonitor == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "performance monitor not available"})
|
||||
return
|
||||
}
|
||||
|
||||
sysStats, gpuStats := pm.perfMonitor.Current()
|
||||
|
||||
var after time.Time
|
||||
if afterStr := c.Query("after"); afterStr != "" {
|
||||
ts, err := time.Parse(time.RFC3339, afterStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid 'after' timestamp, use RFC3339 format"})
|
||||
return
|
||||
}
|
||||
after = ts
|
||||
}
|
||||
|
||||
if !after.IsZero() {
|
||||
filtered := make([]perf.SysStat, 0, len(sysStats))
|
||||
for _, s := range sysStats {
|
||||
if s.Timestamp.After(after) {
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
}
|
||||
sysStats = filtered
|
||||
|
||||
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
|
||||
for _, g := range gpuStats {
|
||||
if g.Timestamp.After(after) {
|
||||
filteredGpu = append(filteredGpu, g)
|
||||
}
|
||||
}
|
||||
gpuStats = filteredGpu
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"sys_stats": sysStats,
|
||||
"gpu_stats": gpuStats,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||
return
|
||||
}
|
||||
|
||||
var stopErr error
|
||||
if pm.matrix != nil {
|
||||
stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
|
||||
} else {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||
return
|
||||
}
|
||||
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
|
||||
}
|
||||
|
||||
if stopErr != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"version": pm.version,
|
||||
"commit": pm.commit,
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||
return
|
||||
}
|
||||
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil || (capture.ReqPath == "" && capture.ReqHeaders == nil && capture.ReqBody == nil && capture.RespHeaders == nil && capture.RespBody == nil) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonBytes)
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
c.Redirect(http.StatusFound, "/ui/")
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.muxLogger.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||
|
||||
// Handle case where query string might be included in the parameter
|
||||
// (can happen with catch-all routes on some versions/setups)
|
||||
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
|
||||
logMonitorId = logMonitorId[:idx]
|
||||
}
|
||||
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
return
|
||||
}
|
||||
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
})()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case data := <-sendChan:
|
||||
c.Writer.Write(data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*logmon.Monitor, error) {
|
||||
switch logMonitorId {
|
||||
case "":
|
||||
// maintain the default
|
||||
return pm.muxLogger, nil
|
||||
case "proxy":
|
||||
return pm.proxyLogger, nil
|
||||
case "upstream":
|
||||
return pm.upstreamLogger, nil
|
||||
default:
|
||||
// search for a models specific logger using findModelInPath
|
||||
// to handle model names with slashes (e.g., "author/model")
|
||||
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||
for _, group := range pm.processGroups {
|
||||
if process, found := group.GetMember(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
// also check the matrix when processGroups doesn't contain the model
|
||||
if pm.matrix != nil {
|
||||
if process, found := pm.matrix.GetProcess(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogMonitorIdQueryParameterStripping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "upstream without query param",
|
||||
input: "upstream",
|
||||
expected: "upstream",
|
||||
},
|
||||
{
|
||||
name: "upstream with query param",
|
||||
input: "upstream?no-history",
|
||||
expected: "upstream",
|
||||
},
|
||||
{
|
||||
name: "proxy with multiple query params",
|
||||
input: "proxy?no-history&foo=bar",
|
||||
expected: "proxy",
|
||||
},
|
||||
{
|
||||
name: "model with slash and query param",
|
||||
input: "author/model?no-history",
|
||||
expected: "author/model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the query parameter stripping logic
|
||||
logMonitorId := tt.input
|
||||
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
|
||||
logMonitorId = logMonitorId[:idx]
|
||||
}
|
||||
|
||||
if logMonitorId != tt.expected {
|
||||
t.Errorf("Query parameter stripping failed: got %q, want %q", logMonitorId, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyManager_GetLogger_ProcessGroups verifies getLogger resolves the
|
||||
// well-known "proxy"/"upstream" loggers and a model ID managed by processGroups.
|
||||
func TestProxyManager_GetLogger_ProcessGroups(t *testing.T) {
|
||||
cfg := testConfigFromYAML(t, `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
||||
`)
|
||||
pm := New(cfg)
|
||||
defer pm.StopProcesses(StopImmediately)
|
||||
|
||||
tests := []struct {
|
||||
id string
|
||||
wantErr bool
|
||||
}{
|
||||
{"proxy", false},
|
||||
{"upstream", false},
|
||||
{"model1", false},
|
||||
{"does-not-exist", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.id, func(t *testing.T) {
|
||||
logger, err := pm.getLogger(tt.id)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid logger")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyManager_GetLogger_Matrix verifies that getLogger can resolve a model
|
||||
// ID when the proxy is configured with a swap matrix (pm.processGroups is empty
|
||||
// for matrix-managed models).
|
||||
func TestProxyManager_GetLogger_Matrix(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"model1", "model2"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
pm := New(cfg)
|
||||
defer pm.StopProcesses(StopImmediately)
|
||||
|
||||
tests := []struct {
|
||||
id string
|
||||
wantErr bool
|
||||
}{
|
||||
{"proxy", false},
|
||||
{"upstream", false},
|
||||
{"model1", false},
|
||||
{"model2", false},
|
||||
{"does-not-exist", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.id, func(t *testing.T) {
|
||||
logger, err := pm.getLogger(tt.id)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid logger")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyManager_StreamLogs_Matrix verifies that /logs/stream/<modelID>
|
||||
// returns 200 (not 400) for a model managed by the swap matrix.
|
||||
func TestProxyManager_StreamLogs_Matrix(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"matrix-model": getTestSimpleResponderConfig("matrix-model"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"matrix-model"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
pm := New(cfg)
|
||||
defer pm.StopProcesses(StopImmediately)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/logs/stream/matrix-model", nil)
|
||||
req = req.WithContext(ctx)
|
||||
rec := CreateTestResponseRecorder()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
pm.ServeHTTP(rec, req)
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
<-done
|
||||
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,43 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single valid value",
|
||||
input: "content-type",
|
||||
expected: "content-type",
|
||||
},
|
||||
{
|
||||
name: "multiple valid values",
|
||||
input: "content-type, authorization, x-requested-with",
|
||||
expected: "content-type, authorization, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "values with extra spaces",
|
||||
input: " content-type , authorization ",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with tabs",
|
||||
input: "content-type,\tauthorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with invalid characters",
|
||||
input: "content-type, auth\n, x-requested-with\r",
|
||||
expected: "content-type, auth, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "empty values in list",
|
||||
input: "content-type,,authorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing commas",
|
||||
input: ",content-type,authorization,",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid values",
|
||||
input: "content-type, \x00invalid, x-requested-with",
|
||||
expected: "content-type, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "mixed case values",
|
||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// ServeCompressedFile serves a file with compression support.
|
||||
// It checks for pre-compressed versions and serves them with proper headers.
|
||||
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||
|
||||
// Try to serve compressed version if client supports it
|
||||
if encoding != "" {
|
||||
if cf, err := fs.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
|
||||
// Verify it's a regular file (not a directory)
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
// Set the content encoding header
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
// Get original file info for content type detection
|
||||
origFile, err := fs.Open(name)
|
||||
if err == nil {
|
||||
origFile.Close()
|
||||
}
|
||||
|
||||
// Serve the compressed file
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serving the uncompressed file
|
||||
file, err := fs.Open(name)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
}
|
||||
@@ -1,283 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with brotli")
|
||||
brContent := []byte("fake-brotli-compressed-data")
|
||||
|
||||
// Create a test filesystem
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that brotli is used (preferred over gzip)
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, brContent) {
|
||||
t.Errorf("Expected brotli content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with gzip")
|
||||
gzContent := []byte("fake-gzip-compressed-data")
|
||||
|
||||
// Create a test filesystem without brotli
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, gzContent) {
|
||||
t.Errorf("Expected gzip content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is uncompressed test content")
|
||||
|
||||
// Create a test filesystem without compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should not have Content-Encoding header since we're serving uncompressed
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content")
|
||||
|
||||
// Create a test filesystem with compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
// No Accept-Encoding header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should serve uncompressed content
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
acceptEncoding string
|
||||
wantEncoding string
|
||||
wantExt string
|
||||
}{
|
||||
{"br, gzip", "br", ".br"},
|
||||
{"gzip, deflate", "gzip", ".gz"},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"br", "br", ".br"},
|
||||
{"", "", ""},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||
{"browser", "", ""},
|
||||
{"compress, deflate", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual pre-compressed files from ui_dist
|
||||
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||
// Check if ui_dist exists
|
||||
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||
t.Skip("ui_dist not found, skipping real file test")
|
||||
}
|
||||
|
||||
// Find a .js or .css file that has compressed versions
|
||||
entries, err := os.ReadDir("./ui_dist/assets")
|
||||
if err != nil {
|
||||
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||
}
|
||||
|
||||
var testFile string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||
// Check if compressed versions exist
|
||||
base := strings.TrimSuffix(name, ".js")
|
||||
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||
testFile = "assets/" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testFile == "" {
|
||||
t.Skip("No suitable test file found with compressed versions")
|
||||
}
|
||||
|
||||
fs := http.FS(os.DirFS("./ui_dist"))
|
||||
|
||||
// Test brotli
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
})
|
||||
|
||||
// Test gzip
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
// Verify it's valid gzip
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid gzip content: %v", err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Just read to verify it's valid
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decompress gzip: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//go:embed ui_dist
|
||||
var reactStaticFS embed.FS
|
||||
|
||||
// GetReactFS returns the embedded React filesystem
|
||||
func GetReactFS() (http.FileSystem, error) {
|
||||
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.FS(subFS), nil
|
||||
}
|
||||
|
||||
// GetReactIndexHTML returns the main index.html for the React app
|
||||
func GetReactIndexHTML() ([]byte, error) {
|
||||
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||
}
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
let isUnloading = $state(false);
|
||||
let menuOpen = $state(false);
|
||||
let pendingLoads = $state<Record<string, boolean>>({});
|
||||
const loadControllers = new Map<string, AbortController>();
|
||||
|
||||
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
||||
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
||||
@@ -42,6 +44,25 @@
|
||||
}
|
||||
}
|
||||
|
||||
async function handleLoadModel(modelId: string): Promise<void> {
|
||||
if (pendingLoads[modelId]) return;
|
||||
const controller = new AbortController();
|
||||
loadControllers.set(modelId, controller);
|
||||
pendingLoads[modelId] = true;
|
||||
try {
|
||||
await loadModel(modelId, controller.signal);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
loadControllers.delete(modelId);
|
||||
delete pendingLoads[modelId];
|
||||
}
|
||||
}
|
||||
|
||||
function cancelLoad(modelId: string): void {
|
||||
loadControllers.get(modelId)?.abort();
|
||||
}
|
||||
|
||||
function toggleIdorName(): void {
|
||||
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
||||
}
|
||||
@@ -170,14 +191,20 @@
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-12">
|
||||
{#if model.state === "stopped"}
|
||||
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
|
||||
{#if model.state === "stopped" && pendingLoads[model.id]}
|
||||
<button class="btn btn--sm" onclick={() => cancelLoad(model.id)}>Cancel</button>
|
||||
{:else if model.state === "stopped"}
|
||||
<button class="btn btn--sm" onclick={() => handleLoadModel(model.id)}>Load</button>
|
||||
{:else}
|
||||
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-20">
|
||||
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||
{#if model.state === "stopped" && pendingLoads[model.id]}
|
||||
<span class="w-16 text-center status status--queued">queued</span>
|
||||
{:else}
|
||||
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||
{/if}
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
|
||||
@@ -145,7 +145,7 @@
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} />
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} capabilities={["audio_transcriptions"]} />
|
||||
</div>
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
|
||||
@@ -0,0 +1,632 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { streamChatCompletion } from "../../lib/chatApi";
|
||||
|
||||
type Status = "waiting" | "streaming" | "done" | "error";
|
||||
type Phase = "waiting" | "loading" | "reasoning" | "content";
|
||||
type RunState = {
|
||||
status: Status;
|
||||
loadingText: string;
|
||||
reasoningContent: string;
|
||||
content: string;
|
||||
loadingDone: boolean;
|
||||
waitingMs: number;
|
||||
loadingMs: number;
|
||||
reasoningMs: number;
|
||||
contentMs: number;
|
||||
phase: Phase;
|
||||
elapsedMs: number;
|
||||
error?: string;
|
||||
};
|
||||
type TestEntry = { id: string; model: string };
|
||||
|
||||
const LOAD_MARKER = "━━━━━";
|
||||
|
||||
const DEFAULT_PROMPT = "Write a few sentences about the history of computing.";
|
||||
const DEFAULT_MAX_TOKENS = 256;
|
||||
|
||||
const promptStore = persistentStore<string>("concurrency-prompt", DEFAULT_PROMPT);
|
||||
const maxTokensStore = persistentStore<number>("concurrency-max-tokens", DEFAULT_MAX_TOKENS);
|
||||
const testListStore = persistentStore<TestEntry[]>("concurrency-test-list", []);
|
||||
|
||||
let runs = $state<Record<string, RunState>>({});
|
||||
let isRunning = $state(false);
|
||||
let abortController: AbortController | null = null;
|
||||
let dragIndex = $state<number | null>(null);
|
||||
let dragOverIndex = $state<number | null>(null);
|
||||
|
||||
const timelineCollapsedStore = persistentStore<boolean>("concurrency-timeline-collapsed", false);
|
||||
|
||||
let timelineMaxMs = $derived(Math.max(100, ...Object.values(runs).map((r) => r.elapsedMs)));
|
||||
|
||||
let availableModels = $derived($models.filter((m) => !m.unlisted));
|
||||
let hasModels = $derived(availableModels.length > 0);
|
||||
let canRun = $derived(!isRunning && $testListStore.length > 0 && $promptStore.trim() !== "");
|
||||
|
||||
function newId(): string {
|
||||
if (typeof crypto !== "undefined" && "randomUUID" in crypto) {
|
||||
return crypto.randomUUID();
|
||||
}
|
||||
return `${Date.now()}-${Math.random().toString(36).slice(2)}`;
|
||||
}
|
||||
|
||||
function addModel(modelId: string) {
|
||||
if (isRunning) return;
|
||||
testListStore.update((list) => [...list, { id: newId(), model: modelId }]);
|
||||
}
|
||||
|
||||
function removeEntry(id: string) {
|
||||
if (isRunning) return;
|
||||
testListStore.update((list) => list.filter((e) => e.id !== id));
|
||||
const next = { ...runs };
|
||||
delete next[id];
|
||||
runs = next;
|
||||
}
|
||||
|
||||
function clearAll() {
|
||||
if (isRunning) return;
|
||||
testListStore.set([]);
|
||||
runs = {};
|
||||
}
|
||||
|
||||
function onDragStart(i: number, e: DragEvent) {
|
||||
if (isRunning) return;
|
||||
dragIndex = i;
|
||||
if (e.dataTransfer) {
|
||||
e.dataTransfer.effectAllowed = "move";
|
||||
e.dataTransfer.setData("text/plain", String(i));
|
||||
}
|
||||
}
|
||||
|
||||
function onDragOver(i: number, e: DragEvent) {
|
||||
if (isRunning || dragIndex === null) return;
|
||||
e.preventDefault();
|
||||
if (e.dataTransfer) e.dataTransfer.dropEffect = "move";
|
||||
dragOverIndex = i;
|
||||
}
|
||||
|
||||
function onDrop(i: number, e: DragEvent) {
|
||||
if (isRunning || dragIndex === null) return;
|
||||
e.preventDefault();
|
||||
const from = dragIndex;
|
||||
const to = i;
|
||||
dragIndex = null;
|
||||
dragOverIndex = null;
|
||||
if (from === to) return;
|
||||
testListStore.update((list) => {
|
||||
const next = [...list];
|
||||
const [moved] = next.splice(from, 1);
|
||||
next.splice(to, 0, moved);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
|
||||
function onDragEnd() {
|
||||
dragIndex = null;
|
||||
dragOverIndex = null;
|
||||
}
|
||||
|
||||
function emptyRun(): RunState {
|
||||
return {
|
||||
status: "waiting",
|
||||
loadingText: "",
|
||||
reasoningContent: "",
|
||||
content: "",
|
||||
loadingDone: false,
|
||||
waitingMs: 0,
|
||||
loadingMs: 0,
|
||||
reasoningMs: 0,
|
||||
contentMs: 0,
|
||||
phase: "waiting",
|
||||
elapsedMs: 0,
|
||||
};
|
||||
}
|
||||
|
||||
// Detect and split the llama-swap loading block (wrapped in ━━━━━ markers,
|
||||
// delivered as reasoning_content) from the model's own reasoning tokens.
|
||||
function ingestReasoning(
|
||||
prev: RunState,
|
||||
chunk: string
|
||||
): { loadingText: string; reasoningContent: string; loadingDone: boolean; nowPhase: Phase } {
|
||||
if (prev.loadingDone) {
|
||||
return {
|
||||
loadingText: prev.loadingText,
|
||||
reasoningContent: prev.reasoningContent + chunk,
|
||||
loadingDone: true,
|
||||
nowPhase: "reasoning",
|
||||
};
|
||||
}
|
||||
|
||||
const combined = prev.loadingText + chunk;
|
||||
// Not enough to decide whether this is a loading marker
|
||||
if (combined.length < LOAD_MARKER.length) {
|
||||
if (LOAD_MARKER.startsWith(combined)) {
|
||||
return { loadingText: combined, reasoningContent: prev.reasoningContent, loadingDone: false, nowPhase: "loading" };
|
||||
}
|
||||
return {
|
||||
loadingText: "",
|
||||
reasoningContent: prev.reasoningContent + combined,
|
||||
loadingDone: true,
|
||||
nowPhase: "reasoning",
|
||||
};
|
||||
}
|
||||
|
||||
if (!combined.startsWith(LOAD_MARKER)) {
|
||||
return {
|
||||
loadingText: "",
|
||||
reasoningContent: prev.reasoningContent + combined,
|
||||
loadingDone: true,
|
||||
nowPhase: "reasoning",
|
||||
};
|
||||
}
|
||||
|
||||
// We're inside a loading block — look for the closing marker
|
||||
const closingIdx = combined.indexOf(LOAD_MARKER, LOAD_MARKER.length);
|
||||
if (closingIdx < 0) {
|
||||
return { loadingText: combined, reasoningContent: prev.reasoningContent, loadingDone: false, nowPhase: "loading" };
|
||||
}
|
||||
const newlineIdx = combined.indexOf("\n", closingIdx);
|
||||
const sliceEnd = newlineIdx >= 0 ? newlineIdx + 1 : combined.length;
|
||||
const loadingPart = combined.substring(0, sliceEnd);
|
||||
// Strip the trailing " \n" the loader sends after the closing marker
|
||||
const remainder = combined.substring(sliceEnd).replace(/^[ \t]*\n?/, "");
|
||||
return {
|
||||
loadingText: loadingPart,
|
||||
reasoningContent: prev.reasoningContent + remainder,
|
||||
loadingDone: true,
|
||||
nowPhase: remainder ? "reasoning" : "waiting",
|
||||
};
|
||||
}
|
||||
|
||||
async function runOne(entry: TestEntry, signal: AbortSignal) {
|
||||
const start = performance.now();
|
||||
let phaseStart = start;
|
||||
runs[entry.id] = { ...emptyRun(), status: "streaming" };
|
||||
|
||||
const accrue = (
|
||||
prev: RunState,
|
||||
now: number
|
||||
): { waitingMs: number; loadingMs: number; reasoningMs: number; contentMs: number } => {
|
||||
const delta = now - phaseStart;
|
||||
const base = {
|
||||
waitingMs: prev.waitingMs,
|
||||
loadingMs: prev.loadingMs,
|
||||
reasoningMs: prev.reasoningMs,
|
||||
contentMs: prev.contentMs,
|
||||
};
|
||||
if (prev.phase === "waiting") return { ...base, waitingMs: base.waitingMs + delta };
|
||||
if (prev.phase === "loading") return { ...base, loadingMs: base.loadingMs + delta };
|
||||
if (prev.phase === "reasoning") return { ...base, reasoningMs: base.reasoningMs + delta };
|
||||
if (prev.phase === "content") return { ...base, contentMs: base.contentMs + delta };
|
||||
return base;
|
||||
};
|
||||
|
||||
const ticker = window.setInterval(() => {
|
||||
const prev = runs[entry.id];
|
||||
if (!prev || prev.status !== "streaming") return;
|
||||
const now = performance.now();
|
||||
const accrued = accrue(prev, now);
|
||||
phaseStart = now;
|
||||
runs[entry.id] = { ...prev, ...accrued, elapsedMs: now - start };
|
||||
}, 50);
|
||||
|
||||
try {
|
||||
const stream = streamChatCompletion(entry.model, [{ role: "user", content: $promptStore }], signal, {
|
||||
endpoint: "v1/chat/completions",
|
||||
max_tokens: $maxTokensStore,
|
||||
});
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.done) break;
|
||||
const prev = runs[entry.id];
|
||||
if (!prev) break;
|
||||
const now = performance.now();
|
||||
const accrued = accrue(prev, now);
|
||||
phaseStart = now;
|
||||
|
||||
let nextPhase: Phase = prev.phase;
|
||||
let loadingText = prev.loadingText;
|
||||
let reasoningContent = prev.reasoningContent;
|
||||
let loadingDone = prev.loadingDone;
|
||||
|
||||
if (chunk.reasoning_content) {
|
||||
const parsed = ingestReasoning(prev, chunk.reasoning_content);
|
||||
loadingText = parsed.loadingText;
|
||||
reasoningContent = parsed.reasoningContent;
|
||||
loadingDone = parsed.loadingDone;
|
||||
nextPhase = parsed.nowPhase;
|
||||
}
|
||||
if (chunk.content) nextPhase = "content";
|
||||
|
||||
runs[entry.id] = {
|
||||
...prev,
|
||||
...accrued,
|
||||
loadingText,
|
||||
reasoningContent,
|
||||
content: prev.content + (chunk.content ?? ""),
|
||||
loadingDone,
|
||||
phase: nextPhase,
|
||||
elapsedMs: now - start,
|
||||
};
|
||||
}
|
||||
const prev = runs[entry.id];
|
||||
if (prev) {
|
||||
const now = performance.now();
|
||||
const accrued = accrue(prev, now);
|
||||
runs[entry.id] = { ...prev, ...accrued, status: "done", elapsedMs: now - start };
|
||||
}
|
||||
} catch (err) {
|
||||
const prev = runs[entry.id] ?? emptyRun();
|
||||
const now = performance.now();
|
||||
const accrued = accrue(prev, now);
|
||||
const aborted = err instanceof Error && err.name === "AbortError";
|
||||
runs[entry.id] = {
|
||||
...prev,
|
||||
...accrued,
|
||||
status: "error",
|
||||
elapsedMs: now - start,
|
||||
error: aborted ? "aborted" : err instanceof Error ? err.message : String(err),
|
||||
};
|
||||
} finally {
|
||||
window.clearInterval(ticker);
|
||||
}
|
||||
}
|
||||
|
||||
async function run() {
|
||||
if (!canRun) return;
|
||||
const entries = $testListStore;
|
||||
const initial: Record<string, RunState> = {};
|
||||
for (const e of entries) {
|
||||
initial[e.id] = emptyRun();
|
||||
}
|
||||
runs = initial;
|
||||
isRunning = true;
|
||||
abortController = new AbortController();
|
||||
try {
|
||||
await Promise.allSettled(entries.map((e) => runOne(e, abortController!.signal)));
|
||||
} finally {
|
||||
isRunning = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function stop() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function waitingBarClass(run: RunState): string {
|
||||
if (run.status === "error" && run.phase === "waiting") return "bg-red-500";
|
||||
return "bg-slate-200 dark:bg-white/10";
|
||||
}
|
||||
|
||||
function loadingBarClass(run: RunState): string {
|
||||
if (run.status === "error" && run.phase === "loading") return "bg-red-500";
|
||||
return "bg-slate-400 dark:bg-slate-500";
|
||||
}
|
||||
|
||||
function reasoningBarClass(run: RunState): string {
|
||||
if (run.status === "error" && run.phase === "reasoning") return "bg-red-500";
|
||||
return "bg-purple-500";
|
||||
}
|
||||
|
||||
function contentBarClass(run: RunState): string {
|
||||
if (run.status === "error" && run.phase === "content") return "bg-red-500";
|
||||
if (run.status === "done") return "bg-green-500";
|
||||
return "bg-amber-400 dark:bg-amber-500";
|
||||
}
|
||||
|
||||
function niceStepMs(maxMs: number): number {
|
||||
if (maxMs <= 500) return 100;
|
||||
if (maxMs <= 2000) return 500;
|
||||
if (maxMs <= 5000) return 1000;
|
||||
if (maxMs <= 20000) return 5000;
|
||||
if (maxMs <= 60000) return 10000;
|
||||
return 30000;
|
||||
}
|
||||
|
||||
function formatTickMs(ms: number): string {
|
||||
if (ms < 1000) return `${ms}`;
|
||||
return `${(ms / 1000).toFixed(ms % 1000 === 0 ? 0 : 1)}s`;
|
||||
}
|
||||
|
||||
let timelineTicks = $derived.by(() => {
|
||||
const step = niceStepMs(timelineMaxMs);
|
||||
const ticks: number[] = [];
|
||||
for (let t = 0; t <= timelineMaxMs; t += step) ticks.push(t);
|
||||
return ticks;
|
||||
});
|
||||
|
||||
function statusBadgeClass(status: Status): string {
|
||||
switch (status) {
|
||||
case "waiting":
|
||||
return "bg-gray-200 text-gray-700 dark:bg-gray-700 dark:text-gray-200";
|
||||
case "streaming":
|
||||
return "bg-amber-200 text-amber-900 dark:bg-amber-500/30 dark:text-amber-200";
|
||||
case "done":
|
||||
return "bg-green-200 text-green-900 dark:bg-green-500/30 dark:text-green-200";
|
||||
case "error":
|
||||
return "bg-red-200 text-red-900 dark:bg-red-500/30 dark:text-red-200";
|
||||
}
|
||||
}
|
||||
|
||||
function formatElapsed(ms: number): string {
|
||||
if (ms < 1000) return `${Math.round(ms)}ms`;
|
||||
return `${(ms / 1000).toFixed(2)}s`;
|
||||
}
|
||||
|
||||
function resetDefaults() {
|
||||
promptStore.set(DEFAULT_PROMPT);
|
||||
maxTokensStore.set(DEFAULT_MAX_TOKENS);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col md:flex-row gap-4 h-full min-h-0">
|
||||
<!-- Left column: run controls, model picker, settings -->
|
||||
<div class="md:w-72 shrink-0 flex flex-col gap-3 min-h-0">
|
||||
<!-- Run controls -->
|
||||
<div class="flex items-center gap-2">
|
||||
{#if isRunning}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white border-red-500" onclick={stop}>
|
||||
<span class="inline-block w-3 h-3 bg-white align-middle mr-2"></span>Stop
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={run}
|
||||
disabled={!canRun}
|
||||
title={$testListStore.length === 0 ? "Add models from the list below" : "Run concurrent requests"}
|
||||
>
|
||||
<span class="inline-block align-middle mr-2" aria-hidden="true">▶</span>Go
|
||||
</button>
|
||||
{/if}
|
||||
<button class="btn btn--sm" onclick={clearAll} disabled={isRunning || $testListStore.length === 0}>
|
||||
Clear ({$testListStore.length})
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Available models -->
|
||||
<div class="flex flex-col min-h-0 flex-1">
|
||||
<div class="text-xs font-medium text-txtsecondary mb-1">
|
||||
Models <span class="text-[10px] font-normal">— click to queue (add the same model more than once to test parallel requests)</span>
|
||||
</div>
|
||||
<div class="flex-1 border border-gray-200 dark:border-white/10 rounded overflow-y-auto min-h-0">
|
||||
{#if !hasModels}
|
||||
<div class="p-3 text-sm text-txtsecondary text-center">No models configured.</div>
|
||||
{:else}
|
||||
<ul class="divide-y divide-gray-100 dark:divide-white/5">
|
||||
{#each availableModels as m (m.id)}
|
||||
<li>
|
||||
<button
|
||||
class="w-full text-left px-2 py-1.5 text-sm hover:bg-secondary-hover transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
|
||||
onclick={() => addModel(m.id)}
|
||||
disabled={isRunning}
|
||||
title="Add {m.id}"
|
||||
>
|
||||
<span class="text-primary" aria-hidden="true">+</span>
|
||||
<span class="truncate flex-1">{m.id}</span>
|
||||
</button>
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Settings -->
|
||||
<div class="flex flex-col gap-2 border-t border-gray-200 dark:border-white/10 pt-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<label for="concurrency-prompt" class="text-xs font-medium text-txtsecondary">Prompt</label>
|
||||
<button
|
||||
class="text-[10px] text-txtsecondary hover:text-txtmain underline"
|
||||
onclick={resetDefaults}
|
||||
disabled={isRunning}
|
||||
>
|
||||
reset defaults
|
||||
</button>
|
||||
</div>
|
||||
<textarea
|
||||
id="concurrency-prompt"
|
||||
class="w-full px-2 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
rows="3"
|
||||
bind:value={$promptStore}
|
||||
disabled={isRunning}
|
||||
></textarea>
|
||||
<label for="concurrency-max-tokens" class="text-xs font-medium text-txtsecondary">max_tokens</label>
|
||||
<input
|
||||
id="concurrency-max-tokens"
|
||||
type="number"
|
||||
min="1"
|
||||
class="w-full px-2 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$maxTokensStore}
|
||||
disabled={isRunning}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Right column: result panels (draggable to reorder) -->
|
||||
<div class="flex-1 min-w-0 min-h-0 overflow-y-auto">
|
||||
{#if $testListStore.length === 0}
|
||||
<div class="h-full flex items-center justify-center px-6">
|
||||
<div class="max-w-md text-sm text-txtsecondary space-y-4">
|
||||
<h4 class="text-base font-semibold text-txtmain pb-0">Load Test</h4>
|
||||
<p>
|
||||
Fire several streaming chat completions at llama-swap at the same time to see how it handles parallel
|
||||
loading and concurrent inference. Each request streams into its own panel with a live timer and status.
|
||||
</p>
|
||||
<ol class="list-decimal list-inside space-y-1">
|
||||
<li>Click models on the left to queue them — repeat a model to hit it with parallel requests.</li>
|
||||
<li>Tweak the prompt and <code>max_tokens</code> if you want.</li>
|
||||
<li>Press <span class="font-semibold text-txtmain">Go</span> to launch them concurrently.</li>
|
||||
</ol>
|
||||
<p class="text-xs">Tip: drag a result card's header to reorder, or hit × to drop it.</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Gantt-style timeline -->
|
||||
<div class="mb-3 border border-gray-200 dark:border-white/10 rounded">
|
||||
<button
|
||||
class="w-full flex items-center gap-2 px-2 py-1.5 text-xs font-medium text-txtsecondary hover:bg-secondary-hover transition-colors {$timelineCollapsedStore ? 'rounded' : 'rounded-t border-b border-gray-200 dark:border-white/10'}"
|
||||
onclick={() => timelineCollapsedStore.update((v) => !v)}
|
||||
aria-expanded={!$timelineCollapsedStore}
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 transition-transform {$timelineCollapsedStore ? '-rotate-90' : ''}"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
aria-hidden="true"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7"></path>
|
||||
</svg>
|
||||
<span>Timeline</span>
|
||||
{#if !$timelineCollapsedStore}
|
||||
<span class="flex items-center gap-3 text-[10px] text-txtsecondary font-normal ml-3" aria-hidden="true">
|
||||
<span class="flex items-center gap-1"><span class="inline-block w-2.5 h-2.5 rounded-sm bg-slate-200 dark:bg-white/10 border border-gray-300 dark:border-white/10"></span>waiting</span>
|
||||
<span class="flex items-center gap-1"><span class="inline-block w-2.5 h-2.5 rounded-sm bg-slate-400 dark:bg-slate-500"></span>loading</span>
|
||||
<span class="flex items-center gap-1"><span class="inline-block w-2.5 h-2.5 rounded-sm bg-purple-500"></span>reasoning</span>
|
||||
<span class="flex items-center gap-1"><span class="inline-block w-2.5 h-2.5 rounded-sm bg-amber-400 dark:bg-amber-500"></span>streaming</span>
|
||||
<span class="flex items-center gap-1"><span class="inline-block w-2.5 h-2.5 rounded-sm bg-green-500"></span>done</span>
|
||||
<span class="flex items-center gap-1"><span class="inline-block w-2.5 h-2.5 rounded-sm bg-red-500"></span>error</span>
|
||||
</span>
|
||||
{/if}
|
||||
<span class="ml-auto tabular-nums text-txtsecondary">
|
||||
max {formatElapsed(timelineMaxMs)} · {$testListStore.length} request{$testListStore.length === 1 ? "" : "s"}
|
||||
</span>
|
||||
</button>
|
||||
{#if !$timelineCollapsedStore}
|
||||
<div class="px-2 py-2">
|
||||
<!-- X axis ticks -->
|
||||
<div class="flex" aria-hidden="true">
|
||||
<div class="w-40 shrink-0"></div>
|
||||
<div class="relative flex-1 h-4 border-b border-gray-200 dark:border-white/10">
|
||||
{#each timelineTicks as t (t)}
|
||||
<div
|
||||
class="absolute top-0 bottom-0 border-l border-gray-200 dark:border-white/10"
|
||||
style="left: {(t / timelineMaxMs) * 100}%;"
|
||||
>
|
||||
<span class="absolute -top-0.5 left-1 text-[10px] text-txtsecondary tabular-nums">{formatTickMs(t)}</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
<div class="w-16 shrink-0"></div>
|
||||
</div>
|
||||
<!-- Bars -->
|
||||
<div class="flex flex-col gap-1 mt-1">
|
||||
{#each $testListStore as entry, i (entry.id)}
|
||||
{@const run = runs[entry.id]}
|
||||
{@const waitingPct = run ? (run.waitingMs / timelineMaxMs) * 100 : 0}
|
||||
{@const loadingPct = run ? (run.loadingMs / timelineMaxMs) * 100 : 0}
|
||||
{@const reasoningPct = run ? (run.reasoningMs / timelineMaxMs) * 100 : 0}
|
||||
{@const contentPct = run ? (run.contentMs / timelineMaxMs) * 100 : 0}
|
||||
<div class="flex items-center text-xs">
|
||||
<div class="w-40 shrink-0 flex items-center gap-1 pr-2 text-txtsecondary">
|
||||
<span class="tabular-nums w-5 text-right">{i + 1}.</span>
|
||||
<span class="truncate" title={entry.model}>{entry.model}</span>
|
||||
</div>
|
||||
<div class="relative flex-1 h-4">
|
||||
{#each timelineTicks as t (t)}
|
||||
<div
|
||||
class="absolute top-0 bottom-0 border-l border-gray-100 dark:border-white/5"
|
||||
style="left: {(t / timelineMaxMs) * 100}%;"
|
||||
aria-hidden="true"
|
||||
></div>
|
||||
{/each}
|
||||
{#if run && run.waitingMs > 0}
|
||||
<div
|
||||
class="absolute top-0.5 bottom-0.5 rounded-l-sm transition-all {waitingBarClass(run)}"
|
||||
style="left: 0; width: {waitingPct}%;"
|
||||
title="waiting {formatElapsed(run.waitingMs)}"
|
||||
></div>
|
||||
{/if}
|
||||
{#if run && run.loadingMs > 0}
|
||||
<div
|
||||
class="absolute top-0.5 bottom-0.5 transition-all {loadingBarClass(run)} {run.waitingMs === 0 ? 'rounded-l-sm' : ''}"
|
||||
style="left: {waitingPct}%; width: {loadingPct}%;"
|
||||
title="loading {formatElapsed(run.loadingMs)}"
|
||||
></div>
|
||||
{/if}
|
||||
{#if run && run.reasoningMs > 0}
|
||||
<div
|
||||
class="absolute top-0.5 bottom-0.5 transition-all {reasoningBarClass(run)} {run.waitingMs === 0 && run.loadingMs === 0 ? 'rounded-l-sm' : ''}"
|
||||
style="left: {waitingPct + loadingPct}%; width: {reasoningPct}%;"
|
||||
title="reasoning {formatElapsed(run.reasoningMs)}"
|
||||
></div>
|
||||
{/if}
|
||||
{#if run && run.contentMs > 0}
|
||||
<div
|
||||
class="absolute top-0.5 bottom-0.5 transition-all {contentBarClass(run)} {run.waitingMs === 0 && run.loadingMs === 0 && run.reasoningMs === 0 ? 'rounded-l-sm' : ''} {run.status === 'done' || run.status === 'error' ? 'rounded-r-sm' : ''}"
|
||||
style="left: {waitingPct + loadingPct + reasoningPct}%; width: {contentPct}%;"
|
||||
title="content {formatElapsed(run.contentMs)}"
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="w-16 shrink-0 pl-2 tabular-nums text-txtsecondary text-right">
|
||||
{run ? formatElapsed(run.elapsedMs) : "—"}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="grid grid-cols-1 lg:grid-cols-2 xl:grid-cols-3 gap-3" role="list">
|
||||
{#each $testListStore as entry, i (entry.id)}
|
||||
{@const run = runs[entry.id]}
|
||||
{@const status = run?.status ?? "waiting"}
|
||||
<div
|
||||
class="border rounded flex flex-col min-h-0 transition-colors {dragOverIndex === i && dragIndex !== i
|
||||
? 'border-primary ring-2 ring-primary/40'
|
||||
: 'border-gray-200 dark:border-white/10'} {dragIndex === i ? 'opacity-40' : ''}"
|
||||
style="height: 280px;"
|
||||
role="listitem"
|
||||
ondragover={(e) => onDragOver(i, e)}
|
||||
ondrop={(e) => onDrop(i, e)}
|
||||
>
|
||||
<div
|
||||
class="shrink-0 flex items-center gap-2 px-2 py-1.5 border-b border-gray-200 dark:border-white/10 bg-secondary/40 rounded-t"
|
||||
draggable={!isRunning}
|
||||
role="button"
|
||||
tabindex="-1"
|
||||
aria-label="Drag to reorder {entry.model}"
|
||||
ondragstart={(e) => onDragStart(i, e)}
|
||||
ondragend={onDragEnd}
|
||||
class:cursor-grab={!isRunning}
|
||||
title={isRunning ? "" : "Drag to reorder"}
|
||||
>
|
||||
<span class="text-txtsecondary select-none" aria-hidden="true">⋮⋮</span>
|
||||
<span class="text-txtsecondary tabular-nums text-xs w-5 text-right">{i + 1}.</span>
|
||||
<span class="flex-1 truncate text-sm font-medium" title={entry.model}>{entry.model}</span>
|
||||
<span class="text-xs tabular-nums text-txtsecondary">
|
||||
{run ? formatElapsed(run.elapsedMs) : "—"}
|
||||
</span>
|
||||
<span class="status text-[10px] {statusBadgeClass(status)}">{status}</span>
|
||||
<button
|
||||
class="w-5 h-5 flex items-center justify-center text-txtsecondary hover:text-red-500 transition-colors rounded disabled:opacity-30 disabled:cursor-not-allowed"
|
||||
onclick={() => removeEntry(entry.id)}
|
||||
disabled={isRunning}
|
||||
aria-label="Remove"
|
||||
tabindex="-1"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
<div class="flex-1 min-h-0 overflow-y-auto font-mono text-xs px-2 py-1.5">
|
||||
{#if run?.loadingText}
|
||||
<div class="bg-secondary/40 dark:bg-white/5 text-txtsecondary rounded px-2 py-1 mb-2 whitespace-pre-wrap">{run.loadingText.trim()}</div>
|
||||
{/if}
|
||||
{#if run?.reasoningContent}
|
||||
<div class="text-purple-700 dark:text-purple-300 whitespace-pre-wrap">{run.reasoningContent}</div>
|
||||
{/if}
|
||||
{#if run?.content}
|
||||
<div class="whitespace-pre-wrap {run.reasoningContent ? 'mt-2' : ''}">{run.content}</div>
|
||||
{/if}
|
||||
{#if run?.status === "error" && run?.error}
|
||||
<div class="text-red-500 mt-2">[error] {run.error}</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -193,7 +193,7 @@
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector and mode toggle -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} />
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} capabilities={["image_generation", "image_to_image"]} matchAny={true} />
|
||||
|
||||
<select
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
|
||||
@@ -6,12 +6,15 @@
|
||||
value: string;
|
||||
placeholder?: string;
|
||||
disabled?: boolean;
|
||||
capabilities?: string[];
|
||||
matchAny?: boolean;
|
||||
}
|
||||
|
||||
let { value = $bindable(), placeholder = "Select a model...", disabled = false }: Props = $props();
|
||||
let { value = $bindable(), placeholder = "Select a model...", disabled = false, capabilities, matchAny = false }: Props = $props();
|
||||
|
||||
let grouped = $derived(groupModels($models));
|
||||
let hasModels = $derived(grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
|
||||
let grouped = $derived(groupModels($models, capabilities, matchAny));
|
||||
let hasMatching = $derived(grouped.localMatching.length > 0);
|
||||
let hasModels = $derived(hasMatching || grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
|
||||
</script>
|
||||
|
||||
{#if hasModels}
|
||||
@@ -21,6 +24,18 @@
|
||||
{disabled}
|
||||
>
|
||||
<option value="">{placeholder}</option>
|
||||
{#if hasMatching}
|
||||
<optgroup label="Matching Capabilities">
|
||||
{#each grouped.localMatching as model (model.id)}
|
||||
<option value={model.id}>{model.id}</option>
|
||||
{#if model.aliases}
|
||||
{#each model.aliases as alias (alias)}
|
||||
<option value={alias}> ↳ {alias}</option>
|
||||
{/each}
|
||||
{/if}
|
||||
{/each}
|
||||
</optgroup>
|
||||
{/if}
|
||||
{#if grouped.local.length > 0}
|
||||
<optgroup label="Local">
|
||||
{#each grouped.local as model (model.id)}
|
||||
|
||||
@@ -264,7 +264,7 @@
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Top bar: model selector + query input (table mode) + mode toggle -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} />
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} capabilities={["reranker"]} />
|
||||
{#if editorMode === "table"}
|
||||
<input
|
||||
type="text"
|
||||
|
||||
@@ -206,7 +206,7 @@
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model and voice selectors -->
|
||||
<div class="shrink-0 flex gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} />
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} capabilities={["audio_speech"]} />
|
||||
<div class="flex gap-2">
|
||||
<select
|
||||
class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
|
||||
@@ -139,7 +139,8 @@
|
||||
}
|
||||
|
||||
.status--starting,
|
||||
.status--stopping {
|
||||
.status--stopping,
|
||||
.status--queued {
|
||||
@apply bg-warning/10 text-warning;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { matchesCapabilities, groupModels } from "./modelUtils";
|
||||
import type { Model } from "./types";
|
||||
|
||||
function makeModel(overrides: Partial<Model> = {}): Model {
|
||||
return {
|
||||
id: "test-model",
|
||||
state: "ready",
|
||||
name: "Test Model",
|
||||
description: "",
|
||||
unlisted: false,
|
||||
peerID: "",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("matchesCapabilities", () => {
|
||||
it("returns true when required is empty", () => {
|
||||
const model = makeModel();
|
||||
expect(matchesCapabilities(model, [])).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false when model has no capabilities", () => {
|
||||
const model = makeModel();
|
||||
expect(matchesCapabilities(model, ["vision"])).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when model has empty capabilities object", () => {
|
||||
const model = makeModel({ capabilities: {} });
|
||||
expect(matchesCapabilities(model, ["vision"])).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when model has the single required capability", () => {
|
||||
const model = makeModel({ capabilities: { vision: true } });
|
||||
expect(matchesCapabilities(model, ["vision"])).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false when model lacks the required capability", () => {
|
||||
const model = makeModel({ capabilities: { vision: true } });
|
||||
expect(matchesCapabilities(model, ["audio_transcriptions"])).toBe(false);
|
||||
});
|
||||
|
||||
it("AND semantics: returns true only when all required are present", () => {
|
||||
const model = makeModel({ capabilities: { vision: true, audio_transcriptions: true } });
|
||||
expect(matchesCapabilities(model, ["vision", "audio_transcriptions"])).toBe(true);
|
||||
expect(matchesCapabilities(model, ["vision", "reranker"])).toBe(false);
|
||||
});
|
||||
|
||||
it("matchAny=true: returns true when at least one required is present", () => {
|
||||
const model = makeModel({ capabilities: { vision: true } });
|
||||
expect(matchesCapabilities(model, ["vision", "reranker"], true)).toBe(true);
|
||||
expect(matchesCapabilities(model, ["audio_transcriptions", "reranker"], true)).toBe(false);
|
||||
});
|
||||
|
||||
it("matchAny=true with empty required returns true", () => {
|
||||
const model = makeModel();
|
||||
expect(matchesCapabilities(model, [], true)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("groupModels", () => {
|
||||
const models: Model[] = [
|
||||
makeModel({ id: "chat-model", capabilities: { vision: true } }),
|
||||
makeModel({ id: "audio-model", capabilities: { audio_transcriptions: true } }),
|
||||
makeModel({ id: "no-caps-model" }),
|
||||
makeModel({ id: "peer-model", peerID: "peer1" }),
|
||||
makeModel({ id: "unlisted-model", unlisted: true, capabilities: { vision: true } }),
|
||||
];
|
||||
|
||||
it("filters out unlisted models", () => {
|
||||
const result = groupModels(models);
|
||||
expect(result.localMatching.length + result.local.length).toBe(3);
|
||||
expect([...result.localMatching, ...result.local].every((m) => !m.unlisted)).toBe(true);
|
||||
});
|
||||
|
||||
it("separates peer models into peersByProvider", () => {
|
||||
const result = groupModels(models);
|
||||
expect(result.peersByProvider["peer1"]).toHaveLength(1);
|
||||
expect(result.peersByProvider["peer1"][0].id).toBe("peer-model");
|
||||
});
|
||||
|
||||
it("without capabilities, all local models go to local (non-matching)", () => {
|
||||
const result = groupModels(models);
|
||||
expect(result.localMatching).toHaveLength(0);
|
||||
expect(result.local).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("with capabilities, matching models go to localMatching", () => {
|
||||
const result = groupModels(models, ["vision"]);
|
||||
expect(result.localMatching).toHaveLength(1);
|
||||
expect(result.localMatching[0].id).toBe("chat-model");
|
||||
expect(result.local).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("with capabilities, models without capabilities go to local", () => {
|
||||
const result = groupModels(models, ["vision"]);
|
||||
expect(result.local.find((m) => m.id === "no-caps-model")).toBeDefined();
|
||||
});
|
||||
|
||||
it("with matchAny, matches models with any listed capability", () => {
|
||||
const result = groupModels(models, ["vision", "audio_transcriptions"], true);
|
||||
expect(result.localMatching).toHaveLength(2);
|
||||
expect(result.localMatching.map((m) => m.id)).toContain("chat-model");
|
||||
expect(result.localMatching.map((m) => m.id)).toContain("audio-model");
|
||||
expect(result.local).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("with empty capabilities array, all local go to local (non-matching)", () => {
|
||||
const result = groupModels(models, []);
|
||||
expect(result.localMatching).toHaveLength(0);
|
||||
expect(result.local).toHaveLength(3);
|
||||
});
|
||||
});
|
||||
@@ -2,14 +2,40 @@ import type { Model } from "./types";
|
||||
|
||||
export interface GroupedModels {
|
||||
local: Model[];
|
||||
localMatching: Model[];
|
||||
peersByProvider: Record<string, Model[]>;
|
||||
}
|
||||
|
||||
export function groupModels(models: Model[]): GroupedModels {
|
||||
export function matchesCapabilities(model: Model, required: string[], matchAny = false): boolean {
|
||||
if (!required.length) return true;
|
||||
if (!model.capabilities) return false;
|
||||
const caps = model.capabilities as Record<string, boolean>;
|
||||
if (matchAny) {
|
||||
return required.some((cap) => caps[cap] === true);
|
||||
}
|
||||
return required.every((cap) => caps[cap] === true);
|
||||
}
|
||||
|
||||
export function groupModels(models: Model[], capabilities?: string[], matchAny = false): GroupedModels {
|
||||
const available = models.filter((m) => !m.unlisted);
|
||||
const local = available.filter((m) => !m.peerID);
|
||||
const peerModels = available.filter((m) => m.peerID);
|
||||
|
||||
let localMatching: Model[] = [];
|
||||
let localRest: Model[] = [];
|
||||
|
||||
if (capabilities && capabilities.length > 0) {
|
||||
for (const model of local) {
|
||||
if (matchesCapabilities(model, capabilities, matchAny)) {
|
||||
localMatching.push(model);
|
||||
} else {
|
||||
localRest.push(model);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
localRest = local;
|
||||
}
|
||||
|
||||
const peersByProvider = peerModels.reduce(
|
||||
(acc, model) => {
|
||||
const peerId = model.peerID || "unknown";
|
||||
@@ -20,5 +46,5 @@ export function groupModels(models: Model[]): GroupedModels {
|
||||
{} as Record<string, Model[]>
|
||||
);
|
||||
|
||||
return { local, peersByProvider };
|
||||
return { local: localRest, localMatching, peersByProvider };
|
||||
}
|
||||
|
||||
@@ -2,6 +2,16 @@ export type ConnectionState = "connected" | "connecting" | "disconnected";
|
||||
|
||||
export type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||
|
||||
export interface ModelCapabilities {
|
||||
vision?: boolean;
|
||||
audio_transcriptions?: boolean;
|
||||
audio_speech?: boolean;
|
||||
image_generation?: boolean;
|
||||
image_to_image?: boolean;
|
||||
function_calling?: boolean;
|
||||
reranker?: boolean;
|
||||
}
|
||||
|
||||
export interface Model {
|
||||
id: string;
|
||||
state: ModelStatus;
|
||||
@@ -10,6 +20,7 @@ export interface Model {
|
||||
unlisted: boolean;
|
||||
peerID: string;
|
||||
aliases?: string[];
|
||||
capabilities?: ModelCapabilities;
|
||||
}
|
||||
|
||||
export interface TokenMetrics {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user