Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cfe5a6639 | |||
| 44e1501e81 | |||
| 46cea36bc2 | |||
| ccfba0df28 | |||
| ddfae90b19 | |||
| 29d3d9ba20 | |||
| 9be9a87fa0 | |||
| 6ea551362e | |||
| 03d58e53fa | |||
| c790d0ee03 | |||
| 4ca9c478a2 | |||
| 146a9eab24 | |||
| 02e015fa49 | |||
| 63bc266395 | |||
| 636b53e70f | |||
| 59cd3b690d | |||
| 5d1e62d224 | |||
| dbb869d019 | |||
| 26bb17e57e | |||
| 2982dd3d40 |
+1
-1
@@ -13,7 +13,7 @@ reviews:
|
|||||||
docstrings:
|
docstrings:
|
||||||
enabled: false
|
enabled: false
|
||||||
auto_review:
|
auto_review:
|
||||||
enabled: true
|
enabled: false
|
||||||
drafts: false
|
drafts: false
|
||||||
chat:
|
chat:
|
||||||
auto_reply: true
|
auto_reply: true
|
||||||
|
|||||||
@@ -13,11 +13,11 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f #v10.2.0
|
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f #v10.2.0
|
||||||
with:
|
with:
|
||||||
days-before-issue-stale: 14
|
days-before-issue-stale: 30
|
||||||
days-before-issue-close: 14
|
days-before-issue-close: 30
|
||||||
stale-issue-label: "stale"
|
stale-issue-label: "stale"
|
||||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
stale-issue-message: "This issue is stale because it has been open without activity for 30 days. Please remove the stale label if this was an error."
|
||||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
close-issue-message: "This issue was closed because it has been inactive for 30 days since being marked as stale."
|
||||||
days-before-pr-stale: -1
|
days-before-pr-stale: -1
|
||||||
days-before-pr-close: -1
|
days-before-pr-close: -1
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ name: Build Containers
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
# time has no specific meaning, trying to time it after
|
# time has no specific meaning, trying to time it after
|
||||||
# the llama.cpp daily packages are published
|
# the llama.cpp daily packages have time to build and publish (~8hr after llama.cpp project's cron)
|
||||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||||
schedule:
|
schedule:
|
||||||
- cron: "37 5 * * *"
|
- cron: "00 12,18 * * *"
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|||||||
@@ -32,11 +32,9 @@ jobs:
|
|||||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
|
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
|
||||||
with:
|
with:
|
||||||
node-version: "24"
|
node-version: "24"
|
||||||
- name: Install dependencies and build UI
|
- name: Build UI
|
||||||
run: |
|
run: |
|
||||||
cd ui-svelte
|
make ui
|
||||||
npm ci
|
|
||||||
npm run build
|
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@1a80836c5c9d9e5755a25cb59ec6f45a3b5f41a8 #7.2.1
|
uses: goreleaser/goreleaser-action@1a80836c5c9d9e5755a25cb59ec6f45a3b5f41a8 #7.2.1
|
||||||
|
|||||||
@@ -5,3 +5,6 @@ dist/
|
|||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.dev/
|
.dev/
|
||||||
|
|
||||||
|
# UI build output; placeholder.txt is kept so the go:embed succeeds.
|
||||||
|
internal/server/ui_dist/*
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
|
|
||||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||||
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
- Run `gofmt -w <file>` before committing to fix any formatting issues.
|
||||||
|
- Build go binaries into the ./build/ subdirectory
|
||||||
- Use `make test-dev` after running new tests for a quick overall test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
- Use `make test-dev` after running new tests for a quick overall test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||||
- Use `make test-ui` after making changes to the UI in ui-svelte/
|
- Use `make test-ui` after making changes to the UI in ui-svelte/
|
||||||
|
|||||||
@@ -19,21 +19,17 @@ all: mac linux simple-responder
|
|||||||
clean:
|
clean:
|
||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
proxy/ui_dist/placeholder.txt:
|
|
||||||
mkdir -p proxy/ui_dist
|
|
||||||
touch $@
|
|
||||||
|
|
||||||
# use cached test results while developing
|
# use cached test results while developing
|
||||||
test-dev: proxy/ui_dist/placeholder.txt
|
test-dev:
|
||||||
go test -short ./proxy/... ./internal/...
|
go test -short ./...
|
||||||
staticcheck ./proxy/... ./internal/... || true
|
staticcheck ./... || true
|
||||||
|
|
||||||
test: proxy/ui_dist/placeholder.txt
|
test:
|
||||||
go test -short -count=1 ./proxy/... ./internal/...
|
go test -short -count=1 ./internal/...
|
||||||
|
|
||||||
# for CI - full test (takes longer)
|
# for CI - full test (takes longer)
|
||||||
test-all: proxy/ui_dist/placeholder.txt
|
test-all:
|
||||||
go test -race -count=1 ./proxy/... ./internal/...
|
go test -race -count=1 ./internal/...
|
||||||
|
|
||||||
ui/node_modules:
|
ui/node_modules:
|
||||||
cd ui-svelte && npm install
|
cd ui-svelte && npm install
|
||||||
@@ -41,6 +37,7 @@ ui/node_modules:
|
|||||||
# build the Svelte UI (ui-svelte/)
|
# build the Svelte UI (ui-svelte/)
|
||||||
ui: ui/node_modules
|
ui: ui/node_modules
|
||||||
cd ui-svelte && npm run build
|
cd ui-svelte && npm run build
|
||||||
|
touch internal/server/ui_dist/placeholder.txt
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac: ui
|
mac: ui
|
||||||
@@ -63,7 +60,7 @@ windows: ui
|
|||||||
@echo "Building Windows binary..."
|
@echo "Building Windows binary..."
|
||||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||||
|
|
||||||
# for testing proxy.Process
|
# for testing with real external processes
|
||||||
simple-responder:
|
simple-responder:
|
||||||
@echo "Building simple responder"
|
@echo "Building simple responder"
|
||||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||||
|
|||||||
@@ -0,0 +1,306 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
var loremWords = strings.Fields(
|
||||||
|
"Lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor " +
|
||||||
|
"incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud " +
|
||||||
|
"exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat Duis aute " +
|
||||||
|
"irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla " +
|
||||||
|
"pariatur Excepteur sint occaecat cupidatat non proident sunt in culpa qui officia " +
|
||||||
|
"deserunt mollit anim id est laborum Sed ut perspiciatis unde omnis iste natus error " +
|
||||||
|
"sit voluptatem accusantium doloremque laudantium totam rem aperiam eaque ipsa quae " +
|
||||||
|
"ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo " +
|
||||||
|
"Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit",
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
flagListen = flag.String("listen", "localhost:9898", "listen address")
|
||||||
|
flagTokens = flag.Int("tokens", 1000, "number of tokens to return")
|
||||||
|
flagTPS = flag.Float64("tps", 75, "tokens per second")
|
||||||
|
flagLoad = flag.String("load", "0s", "simulated load duration (e.g. 2s, 500ms)")
|
||||||
|
)
|
||||||
|
|
||||||
|
// chunkDelta is the incremental message fragment carried by one streamed
// chunk. Empty fields are omitted from the JSON.
type chunkDelta struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

// chunkChoice is a single choice inside a streamed chunk. FinishReason is a
// pointer so that an unset value is encoded as JSON null.
type chunkChoice struct {
	Index        int        `json:"index"`
	Delta        chunkDelta `json:"delta"`
	FinishReason *string    `json:"finish_reason"`
}

// chatChunk is the payload of one SSE "data:" event in a streaming response.
type chatChunk struct {
	ID      string        `json:"id"`
	Object  string        `json:"object"`
	Created int64         `json:"created"`
	Model   string        `json:"model"`
	Choices []chunkChoice `json:"choices"`
}

// completionMessage is the full assistant message of a non-streaming response.
type completionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// completionChoice is a single choice inside a non-streaming response.
type completionChoice struct {
	Index        int               `json:"index"`
	Message      completionMessage `json:"message"`
	FinishReason string            `json:"finish_reason"`
}

// completionUsage reports token counts for a non-streaming response.
type completionUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// chatCompletion is the JSON body returned when the request is not streaming.
type chatCompletion struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []completionChoice `json:"choices"`
	Usage   completionUsage    `json:"usage"`
}
|
||||||
|
|
||||||
|
func loremText(n int) string {
|
||||||
|
words := make([]string, n)
|
||||||
|
for i := range words {
|
||||||
|
words[i] = loremWords[i%len(loremWords)]
|
||||||
|
}
|
||||||
|
return strings.Join(words, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendChunk(w http.ResponseWriter, content string, finishReason *string) error {
|
||||||
|
chunk := chatChunk{
|
||||||
|
ID: "chatcmpl-fake",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: "fake-model",
|
||||||
|
Choices: []chunkChoice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Delta: chunkDelta{Content: content},
|
||||||
|
FinishReason: finishReason,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(chunk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = fmt.Fprintf(w, "data: %s\n\n", data)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// startLoading simulates a model load. It returns a channel that is closed
// once loadDur has elapsed; until then a background goroutine logs the
// remaining time once per second.
//
// A zero or negative loadDur closes the channel before returning, so callers
// never observe a transient "not ready" state (and no negative countdown is
// logged) for a duration that has already expired.
func startLoading(loadDur time.Duration) <-chan struct{} {
	ready := make(chan struct{})
	if loadDur <= 0 {
		close(ready)
		return ready
	}

	go func() {
		deadline := time.Now().Add(loadDur)
		log.Printf("loading... %s remaining", loadDur.Round(time.Second))

		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		timer := time.NewTimer(loadDur)
		defer timer.Stop()

		for {
			select {
			case <-timer.C:
				close(ready)
				log.Printf("ready")
				return
			case <-ticker.C:
				// Skip the tick that rounds down to zero right before the timer fires.
				if rem := time.Until(deadline).Round(time.Second); rem > 0 {
					log.Printf("loading... %s remaining", rem)
				}
			}
		}
	}()
	return ready
}
|
||||||
|
|
||||||
|
// healthHandler returns a handler that answers 200 OK once ready has been
// closed and 503 Service Unavailable before that. It never blocks.
func healthHandler(ready <-chan struct{}) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		code := http.StatusServiceUnavailable
		select {
		case <-ready:
			code = http.StatusOK
		default:
			// still loading
		}
		w.WriteHeader(code)
	}
}
|
||||||
|
|
||||||
|
func chatHandler(ready <-chan struct{}) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to read body", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
streaming := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ready:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := *flagTokens
|
||||||
|
tps := *flagTPS
|
||||||
|
if tps <= 0 {
|
||||||
|
tps = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if !streaming {
|
||||||
|
delay := time.Duration(float64(tokens) / tps * float64(time.Second))
|
||||||
|
select {
|
||||||
|
case <-time.After(delay):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text := loremText(tokens)
|
||||||
|
resp := chatCompletion{
|
||||||
|
ID: "chatcmpl-fake",
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: "fake-model",
|
||||||
|
Choices: []completionChoice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Message: completionMessage{Role: "assistant", Content: text},
|
||||||
|
FinishReason: "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: completionUsage{
|
||||||
|
PromptTokens: 0,
|
||||||
|
CompletionTokens: tokens,
|
||||||
|
TotalTokens: tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send role delta first
|
||||||
|
first := chatChunk{
|
||||||
|
ID: "chatcmpl-fake",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: "fake-model",
|
||||||
|
Choices: []chunkChoice{
|
||||||
|
{Index: 0, Delta: chunkDelta{Role: "assistant"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if data, err := json.Marshal(first); err == nil {
|
||||||
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
interval := time.Duration(float64(time.Second) / tps)
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
stop := "stop"
|
||||||
|
for i := 0; i < tokens; i++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
word := loremWords[i%len(loremWords)]
|
||||||
|
if i < tokens-1 {
|
||||||
|
if err := sendChunk(w, word+" ", nil); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := sendChunk(w, word, &stop); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
loadDur, err := time.ParseDuration(*flagLoad)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("invalid -load value %q: %v", *flagLoad, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ready := startLoading(loadDur)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/health", healthHandler(ready))
|
||||||
|
mux.HandleFunc("/v1/chat/completions", chatHandler(ready))
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: *flagListen,
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Printf("listening on %s (tokens=%d tps=%.1f load=%s)",
|
||||||
|
*flagListen, *flagTokens, *flagTPS, loadDur)
|
||||||
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-quit
|
||||||
|
|
||||||
|
log.Println("shutting down...")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
log.Printf("shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,9 +8,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func printSysStat(s perf.SysStat) {
|
func printSysStat(s perf.SysStat) {
|
||||||
|
|||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
prompt := flag.String("prompt", "Write a few sentences about the history of computing.", "user message sent to each model")
|
||||||
|
maxTokens := flag.Int("max-tokens", 256, "max_tokens per request")
|
||||||
|
flag.Usage = func() {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <base-url> <model> [model...]\n", os.Args[0])
|
||||||
|
fmt.Fprintf(os.Stderr, "Example: %s -max-tokens 400 http://localhost:8080 A B C D\n\n", os.Args[0])
|
||||||
|
flag.PrintDefaults()
|
||||||
|
}
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
args := flag.Args()
|
||||||
|
if len(args) < 2 {
|
||||||
|
flag.Usage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := args[0]
|
||||||
|
models := args[1:]
|
||||||
|
|
||||||
|
m := newModel(models)
|
||||||
|
prog := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
|
||||||
|
|
||||||
|
// Chain of triggers ensures requests are sent in the order provided.
|
||||||
|
triggers := make([]chan struct{}, len(models))
|
||||||
|
for i := range triggers {
|
||||||
|
triggers[i] = make(chan struct{}, 1)
|
||||||
|
}
|
||||||
|
triggers[0] <- struct{}{}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
for i, name := range models {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int, mdl string) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
<-triggers[idx]
|
||||||
|
|
||||||
|
reqStart := time.Now()
|
||||||
|
prog.Send(statusMsg{idx: idx, status: statusStreaming})
|
||||||
|
|
||||||
|
if idx+1 < len(triggers) {
|
||||||
|
triggers[idx+1] <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sendRequest(baseURL, mdl, *prompt, *maxTokens, idx, func(i int, text string) {
|
||||||
|
prog.Send(deltaMsg{idx: i, text: text})
|
||||||
|
})
|
||||||
|
|
||||||
|
elapsed := time.Since(reqStart)
|
||||||
|
if err != nil {
|
||||||
|
prog.Send(statusMsg{idx: idx, status: statusError, elapsed: elapsed, err: err})
|
||||||
|
} else {
|
||||||
|
prog.Send(statusMsg{idx: idx, status: statusDone, elapsed: elapsed})
|
||||||
|
}
|
||||||
|
}(i, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := prog.Run(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
printSummary(m, start)
|
||||||
|
}
|
||||||
|
|
||||||
|
func printSummary(m *model, start time.Time) {
|
||||||
|
fmt.Println("Summary:")
|
||||||
|
for _, p := range m.panels {
|
||||||
|
switch p.status {
|
||||||
|
case statusError:
|
||||||
|
fmt.Printf(" [%d] %-20s ERROR elapsed=%s err=%v\n",
|
||||||
|
p.idx, p.model, p.elapsed.Round(time.Millisecond), p.err)
|
||||||
|
case statusDone:
|
||||||
|
fmt.Printf(" [%d] %-20s done elapsed=%s\n",
|
||||||
|
p.idx, p.model, p.elapsed.Round(time.Millisecond))
|
||||||
|
default:
|
||||||
|
fmt.Printf(" [%d] %-20s %s\n", p.idx, p.model, p.status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("all done in %s\n", time.Since(start).Round(time.Millisecond))
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// deltaSink receives streamed text fragments for a given model panel.
type deltaSink func(idx int, text string)

// streamDelta is the part of a chunk's delta that gets rendered: regular
// assistant content and, when the server sends it, reasoning content.
type streamDelta struct {
	Content          string `json:"content"`
	ReasoningContent string `json:"reasoning_content"`
}

// streamChoice is one choice of a streamed chunk.
type streamChoice struct {
	Delta streamDelta `json:"delta"`
}

// streamChunk is the decoded payload of one SSE "data:" event.
type streamChunk struct {
	Choices []streamChoice `json:"choices"`
}

// maxErrorBody caps how much of a non-200 response body is copied into the
// error returned by sendRequest.
const maxErrorBody = 4 << 10

// completionsURL joins baseURL with the chat completions path, tolerating
// trailing slashes on baseURL so the request path never starts with "//".
func completionsURL(baseURL string) string {
	return strings.TrimRight(baseURL, "/") + "/v1/chat/completions"
}

// sendRequest streams a chat completion and forwards each content/reasoning
// delta to sink. Reasoning and assistant content are emitted into the same
// stream so they render together.
//
// It returns an error when the request cannot be built or sent, when the
// server answers with a status other than 200 (the error includes up to
// maxErrorBody bytes of the body), or when reading the stream fails. Events
// whose payload is not valid JSON are skipped. Reading stops at "[DONE]".
func sendRequest(baseURL, model, prompt string, maxTokens, idx int, sink deltaSink) error {
	payload := map[string]any{
		"model": model,
		"messages": []map[string]string{
			{"role": "user", "content": prompt},
		},
		"max_tokens": maxTokens,
		"stream":     true,
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	resp, err := http.Post(completionsURL(baseURL), "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: if the body cannot be read the status code alone is
		// still reported, so the read error is deliberately dropped.
		b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(b)))
	}

	scanner := bufio.NewScanner(resp.Body)
	// Allow events of up to 1 MiB; the default 64 KiB line limit is too small
	// for large deltas.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "[DONE]" {
			break
		}
		if data == "" {
			continue
		}

		var chunk streamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed events rather than aborting the stream
		}
		for _, c := range chunk.Choices {
			if c.Delta.ReasoningContent != "" {
				sink(idx, c.Delta.ReasoningContent)
			}
			if c.Delta.Content != "" {
				sink(idx, c.Delta.Content)
			}
		}
	}

	return scanner.Err()
}
|
||||||
@@ -0,0 +1,343 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// panelStatus is the lifecycle state of a single model panel.
type panelStatus int

const (
	statusWaiting panelStatus = iota
	statusStreaming
	statusDone
	statusError
)

// String returns the lower-case label for s. Any value other than
// statusStreaming, statusDone or statusError is reported as "waiting".
func (s panelStatus) String() string {
	labels := [...]string{
		statusStreaming: "streaming",
		statusDone:      "done",
		statusError:     "error",
	}
	if s > statusWaiting && int(s) < len(labels) {
		return labels[s]
	}
	return "waiting"
}
|
||||||
|
|
||||||
|
// deltaMsg appends streamed text to a panel.
type deltaMsg struct {
	idx  int    // index into model.panels
	text string // fragment appended to that panel's buffer
}

// statusMsg updates a panel's lifecycle state.
type statusMsg struct {
	idx     int // index into model.panels
	status  panelStatus
	elapsed time.Duration
	err     error // when non-nil, its text is also appended to the panel buffer
}

// panel holds the display state for one requested model.
type panel struct {
	idx     int            // position in model.panels
	model   string         // model name as given on the command line
	color   lipgloss.Color // color used for the model name in the header
	status  panelStatus
	buf     strings.Builder // all text received so far (plus any error text)
	elapsed time.Duration
	err     error
}
|
||||||
|
|
||||||
|
// Layout constants for the panel grid.
const (
	minPanelWidth = 28 // used by relayout to decide how many columns fit
	maxCols       = 3  // upper bound on grid columns
	panelHeight   = 9  // total box height including border + header
)

// model is the bubbletea model for the panel grid. A single viewport is shared
// by all panels and always shows the focused panel's buffer.
type model struct {
	panels  []*panel
	focused int            // index of the panel whose buffer is loaded into vp
	vp      viewport.Model // rebuilt by relayout on every window resize
	width   int            // last reported terminal width
	height  int            // last reported terminal height
	cols    int            // grid columns computed by relayout
	pw      int            // inner panel content width
	ready   bool           // set by relayout once a window size has been received
}
|
||||||
|
|
||||||
|
func newModel(models []string) *model {
|
||||||
|
// Assign a stable color per unique model name (by first appearance).
|
||||||
|
colorOf := map[string]lipgloss.Color{}
|
||||||
|
panels := make([]*panel, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
c, ok := colorOf[m]
|
||||||
|
if !ok {
|
||||||
|
c = modelPalette[len(colorOf)%len(modelPalette)]
|
||||||
|
colorOf[m] = c
|
||||||
|
}
|
||||||
|
panels[i] = &panel{idx: i, model: m, color: c, status: statusWaiting}
|
||||||
|
}
|
||||||
|
return &model{panels: panels, focused: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init returns no initial command; all activity is driven by messages sent to
// the program from outside.
func (m *model) Init() tea.Cmd { return nil }
|
||||||
|
|
||||||
|
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.width = msg.Width
|
||||||
|
m.height = msg.Height
|
||||||
|
m.relayout()
|
||||||
|
m.refreshViewport(true)
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case tea.KeyMsg:
|
||||||
|
switch msg.String() {
|
||||||
|
case "q", "ctrl+c", "esc":
|
||||||
|
return m, tea.Quit
|
||||||
|
case "tab", "right", "l":
|
||||||
|
m.setFocus(m.focused + 1)
|
||||||
|
return m, nil
|
||||||
|
case "shift+tab", "left", "h":
|
||||||
|
m.setFocus(m.focused - 1)
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.vp, cmd = m.vp.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
|
||||||
|
case tea.MouseMsg:
|
||||||
|
if msg.Action == tea.MouseActionPress && msg.Button == tea.MouseButtonLeft {
|
||||||
|
if idx, ok := m.panelAt(msg.X, msg.Y); ok {
|
||||||
|
m.setFocus(idx)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.vp, cmd = m.vp.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
|
||||||
|
case deltaMsg:
|
||||||
|
p := m.panels[msg.idx]
|
||||||
|
p.buf.WriteString(msg.text)
|
||||||
|
if msg.idx == m.focused {
|
||||||
|
atBottom := m.vp.AtBottom()
|
||||||
|
m.refreshViewport(false)
|
||||||
|
if atBottom {
|
||||||
|
m.vp.GotoBottom()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case statusMsg:
|
||||||
|
p := m.panels[msg.idx]
|
||||||
|
p.status = msg.status
|
||||||
|
p.elapsed = msg.elapsed
|
||||||
|
p.err = msg.err
|
||||||
|
if msg.err != nil {
|
||||||
|
errTxt := lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Render("\n" + msg.err.Error())
|
||||||
|
p.buf.WriteString(errTxt)
|
||||||
|
if msg.idx == m.focused {
|
||||||
|
m.refreshViewport(false)
|
||||||
|
m.vp.GotoBottom()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) setFocus(idx int) {
|
||||||
|
if len(m.panels) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if idx < 0 {
|
||||||
|
idx = len(m.panels) - 1
|
||||||
|
}
|
||||||
|
if idx >= len(m.panels) {
|
||||||
|
idx = 0
|
||||||
|
}
|
||||||
|
if idx == m.focused {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.focused = idx
|
||||||
|
m.refreshViewport(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// relayout recomputes grid columns and panel/viewport dimensions.
|
||||||
|
func (m *model) relayout() {
|
||||||
|
if m.width < minPanelWidth+4 {
|
||||||
|
m.cols = 1
|
||||||
|
} else {
|
||||||
|
m.cols = m.width / (minPanelWidth + 2)
|
||||||
|
if m.cols > maxCols {
|
||||||
|
m.cols = maxCols
|
||||||
|
}
|
||||||
|
if m.cols > len(m.panels) {
|
||||||
|
m.cols = len(m.panels)
|
||||||
|
}
|
||||||
|
if m.cols < 1 {
|
||||||
|
m.cols = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inner content width: total width / cols, minus borders+padding (4) and gap.
|
||||||
|
boxOuter := m.width/m.cols - 1
|
||||||
|
m.pw = boxOuter - 4
|
||||||
|
if m.pw < 8 {
|
||||||
|
m.pw = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
m.vp = viewport.New(m.pw, panelHeight-2)
|
||||||
|
m.ready = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) refreshViewport(reset bool) {
|
||||||
|
if !m.ready || len(m.panels) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content := lipgloss.NewStyle().Width(m.pw).Render(m.panels[m.focused].buf.String())
|
||||||
|
m.vp.SetContent(content)
|
||||||
|
if reset {
|
||||||
|
m.vp.GotoBottom()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// panelAt maps screen coordinates to a panel index based on the grid layout.
|
||||||
|
func (m *model) panelAt(x, y int) (int, bool) {
|
||||||
|
if m.cols == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
boxOuterW := m.width/m.cols + 1
|
||||||
|
col := x / boxOuterW
|
||||||
|
row := y / panelHeight
|
||||||
|
idx := row*m.cols + col
|
||||||
|
if col < m.cols && idx >= 0 && idx < len(m.panels) {
|
||||||
|
return idx, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) View() string {
|
||||||
|
if !m.ready {
|
||||||
|
return "loading..."
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := []string{}
|
||||||
|
var current []string
|
||||||
|
for i, p := range m.panels {
|
||||||
|
current = append(current, m.renderPanel(p, i == m.focused))
|
||||||
|
if len(current) == m.cols {
|
||||||
|
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
|
||||||
|
current = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(current) > 0 {
|
||||||
|
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
|
||||||
|
}
|
||||||
|
|
||||||
|
grid := lipgloss.JoinVertical(lipgloss.Left, rows...)
|
||||||
|
footer := lipgloss.NewStyle().Faint(true).Render(
|
||||||
|
"tab/click: focus panel • wheel/↑↓/pgup/pgdn: scroll focused • q: quit")
|
||||||
|
return grid + "\n" + footer
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelPalette gives each panel a distinct, readable color for its name.
|
||||||
|
var modelPalette = []lipgloss.Color{
|
||||||
|
"39", // blue
|
||||||
|
"213", // magenta
|
||||||
|
"214", // orange
|
||||||
|
"45", // cyan
|
||||||
|
"141", // purple
|
||||||
|
"203", // salmon
|
||||||
|
"82", // lime
|
||||||
|
"227", // light yellow
|
||||||
|
}
|
||||||
|
|
||||||
|
func statusColor(s panelStatus) lipgloss.Color {
|
||||||
|
switch s {
|
||||||
|
case statusStreaming:
|
||||||
|
return lipgloss.Color("220") // yellow - active
|
||||||
|
case statusDone:
|
||||||
|
return lipgloss.Color("42") // green - success
|
||||||
|
case statusError:
|
||||||
|
return lipgloss.Color("196") // red - error
|
||||||
|
default:
|
||||||
|
return lipgloss.Color("244") // gray - waiting
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) renderPanel(p *panel, focused bool) string {
|
||||||
|
border := lipgloss.RoundedBorder()
|
||||||
|
if focused {
|
||||||
|
border = lipgloss.DoubleBorder()
|
||||||
|
}
|
||||||
|
style := lipgloss.NewStyle().
|
||||||
|
Border(border).
|
||||||
|
BorderForeground(lipgloss.Color("240"))
|
||||||
|
|
||||||
|
statusTxt := p.status.String()
|
||||||
|
if p.elapsed > 0 {
|
||||||
|
statusTxt += " " + p.elapsed.Round(time.Millisecond).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header: model name (left, model color) + status/timer (right, status color).
|
||||||
|
name := fmt.Sprintf("[%d] %s", p.idx, p.model)
|
||||||
|
gap := m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
|
||||||
|
if gap < 1 {
|
||||||
|
name = truncate(name, m.pw-lipgloss.Width(statusTxt)-1)
|
||||||
|
gap = m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
|
||||||
|
}
|
||||||
|
if gap < 1 {
|
||||||
|
gap = 1
|
||||||
|
}
|
||||||
|
header := lipgloss.NewStyle().Bold(true).Foreground(p.color).Render(name) +
|
||||||
|
strings.Repeat(" ", gap) +
|
||||||
|
lipgloss.NewStyle().Foreground(statusColor(p.status)).Render(statusTxt)
|
||||||
|
|
||||||
|
var bodyLines string
|
||||||
|
if focused {
|
||||||
|
bodyLines = m.vp.View()
|
||||||
|
} else {
|
||||||
|
bodyLines = tailLines(p.buf.String(), m.pw, panelHeight-2)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := lipgloss.JoinVertical(lipgloss.Left, header, bodyLines)
|
||||||
|
return style.Width(m.pw).Height(panelHeight - 2).Render(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncate(s string, w int) string {
|
||||||
|
if w <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if lipgloss.Width(s) <= w {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
r := []rune(s)
|
||||||
|
if len(r) > w {
|
||||||
|
r = r[:w]
|
||||||
|
}
|
||||||
|
return string(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tailLines wraps text to width w and returns the last n lines.
|
||||||
|
func tailLines(s string, w, n int) string {
|
||||||
|
wrapped := lipgloss.NewStyle().Width(w).Render(s)
|
||||||
|
lines := strings.Split(wrapped, "\n")
|
||||||
|
if len(lines) > n {
|
||||||
|
lines = lines[len(lines)-n:]
|
||||||
|
}
|
||||||
|
for len(lines) < n {
|
||||||
|
lines = append(lines, "")
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
+13
-4
@@ -281,7 +281,7 @@ models:
|
|||||||
b: 2
|
b: 2
|
||||||
# objects can contain complex types with macro substitution
|
# objects can contain complex types with macro substitution
|
||||||
# becomes: c: [0.7, false, "model: llama"]
|
# becomes: c: [0.7, false, "model: llama"]
|
||||||
c: [ "${temp}", false, "model: ${MODEL_ID}" ]
|
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||||
|
|
||||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||||
# - optional, default: 0
|
# - optional, default: 0
|
||||||
@@ -347,11 +347,20 @@ models:
|
|||||||
# matrix: run concurrent models with a solver-based swap DSL
|
# matrix: run concurrent models with a solver-based swap DSL
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
#
|
#
|
||||||
# Note:
|
# Matrix or Groups?
|
||||||
# A config must use either a matrix or legacy groups, not both. A configuration error
|
#
|
||||||
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
# Groups are available and fully supported. The syntax may be easier to use
|
||||||
|
# for simple use cases.
|
||||||
|
#
|
||||||
|
# Documentation can be found here:
|
||||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||||
#
|
#
|
||||||
|
# A config can use either a matrix (recommended) or groups, but not both. A configuration error
|
||||||
|
# will occur if both are defined. Groups is legacy but is fully supported with
|
||||||
|
# no plans to deprecate it.
|
||||||
|
#
|
||||||
|
# ~~~~~
|
||||||
|
#
|
||||||
# The matrix declares valid combinations of models that can run concurrently.
|
# The matrix declares valid combinations of models that can run concurrently.
|
||||||
# When a model is requested, the solver finds the cheapest way to make it
|
# When a model is requested, the solver finds the cheapest way to make it
|
||||||
# available by evicting as few (and least costly) running models as possible.
|
# available by evicting as few (and least costly) running models as possible.
|
||||||
|
|||||||
@@ -2,10 +2,6 @@ ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
|||||||
ARG BASE_TAG=server-cuda
|
ARG BASE_TAG=server-cuda
|
||||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||||
|
|
||||||
# has to be after the FROM
|
|
||||||
# TARGETARCH is auto-set by `docker buildx build --platform …` (amd64/arm64);
|
|
||||||
# falls back to amd64 when an older `docker build` runs without buildx.
|
|
||||||
ARG TARGETARCH=amd64
|
|
||||||
ARG LS_VER=170
|
ARG LS_VER=170
|
||||||
ARG LS_REPO=mostlygeek/llama-swap
|
ARG LS_REPO=mostlygeek/llama-swap
|
||||||
|
|
||||||
@@ -37,9 +33,15 @@ WORKDIR /app
|
|||||||
ENV PATH="/app:${PATH}"
|
ENV PATH="/app:${PATH}"
|
||||||
|
|
||||||
RUN \
|
RUN \
|
||||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
set -eux; \
|
||||||
tar -zxf "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
|
case "$(uname -m)" in \
|
||||||
rm "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz"
|
x86_64) ARCH=amd64 ;; \
|
||||||
|
aarch64) ARCH=arm64 ;; \
|
||||||
|
*) echo "unsupported arch: $(uname -m)" >&2; exit 1 ;; \
|
||||||
|
esac; \
|
||||||
|
curl --fail -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||||
|
tar -zxf "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
|
||||||
|
rm "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz"
|
||||||
|
|
||||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,264 @@
|
|||||||
|
# New Router Migration TODO
|
||||||
|
|
||||||
|
This document tracks the work needed for [cmd/newrouter/main.go](../cmd/newrouter/main.go) and [internal/router/](../internal/router/) to reach feature parity with the legacy entrypoint at [llama-swap.go](../llama-swap.go) plus [proxy/proxymanager.go](../proxy/proxymanager.go).
|
||||||
|
|
||||||
|
The work is split into phases so each can land and be tested independently. Earlier phases unblock later ones.
|
||||||
|
|
||||||
|
## Current state (newrouter)
|
||||||
|
|
||||||
|
`cmd/newrouter` already supports:
|
||||||
|
|
||||||
|
- Loading config via `-config`
|
||||||
|
- Selecting Matrix vs Group router based on config
|
||||||
|
- Peer routing fallback
|
||||||
|
- Plain HTTP listen (`-listen`)
|
||||||
|
- Graceful shutdown on `SIGINT` / `SIGTERM`
|
||||||
|
- Model extraction from JSON body, query string, and form bodies (see [router.go:88](../internal/router/router.go#L88))
|
||||||
|
- `Server.ServeHTTP` dispatches a single request to peer or local router based on the requested model
|
||||||
|
|
||||||
|
Everything below is missing or only partially implemented.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 1 — Package relocation -- Completed.
|
||||||
|
|
||||||
|
Goal: move shared infrastructure packages out from under `proxy/` so the new router does not depend on the legacy proxy tree. This is a prerequisite for retiring `proxy/` in Phase 8.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 2 — Server lifecycle parity -- Completed.
|
||||||
|
|
||||||
|
Goal: make `cmd/newrouter` a drop-in replacement for the legacy binary's process model, _without_ yet adding any extra HTTP endpoints.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 3 — `internal/chain` package -- Completed.
|
||||||
|
|
||||||
|
API: `chain.New(mws...).Then(final)` for ServeMux registration; `Append` returns an extended Chain without mutating the receiver, so a base stack (auth/CORS) can be reused across many routes with per-route additions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 4 — `internal/server` package scaffolding (ProxyManager replacement) -- Completed.
|
||||||
|
|
||||||
|
Goal: build the [internal/server](../internal/server/) package so it can stand in for [proxy.ProxyManager](../proxy/proxymanager.go#L67) — the mux, lifecycle, model dispatch, custom endpoints, request filters, auth/CORS, and upstream passthrough. After this phase, `cmd/newrouter/main.go` constructs a `server.Server` instead of a bare `router.Server`.
|
||||||
|
|
||||||
|
The legacy `ProxyManager` collapses three concerns into one struct: the HTTP mux, the model→process router, and the cross-cutting services (loggers, metrics, perf, inflight counter, version). The new layout keeps the `router.Router` implementations focused on model dispatch and lets `internal/server.Server` own the mux and all cross-cutting middleware. `server.Server` builds the `local` and `peer` routers directly and dispatches between them itself, so it fully **supersedes `internal/router.Server`** — see the cleanup item below.
|
||||||
|
|
||||||
|
The phase is split into sub-phases that can land and be tested independently:
|
||||||
|
|
||||||
|
| Sub-phase | Scope |
|
||||||
|
| --------- | -------------------------------------------------------------------------- |
|
||||||
|
| 4a | package scaffolding — struct, `New`, `ServeHTTP`, `Shutdown`, model routes |
|
||||||
|
| 4b | custom (non-model-dispatched) HTTP endpoints |
|
||||||
|
| 4c | request-body filter middleware |
|
||||||
|
| 4d | auth & CORS middleware |
|
||||||
|
| 4e | upstream passthrough |
|
||||||
|
|
||||||
|
The package is split by concern across stub files already in place:
|
||||||
|
|
||||||
|
| File | Responsibility | Filled in by |
|
||||||
|
| ------------ | ----------------------------------------------- | ---------------------- |
|
||||||
|
| `server.go` | `Server` struct, `New`, `ServeHTTP`, `Shutdown` | 4a |
|
||||||
|
| `log.go` | `muxlog` combined logger; `/logs` handlers | 4a |
|
||||||
|
| `auth.go` | `CreateAuthMiddleware` | 4d |
|
||||||
|
| `filters.go` | request-body filter middleware | 4c |
|
||||||
|
| `api.go` | llama-swap-specific API handlers | 4b / Phase 5 / Phase 6 |
|
||||||
|
| `ui.go` | embedded UI serving | Phase 7 |
|
||||||
|
|
||||||
|
### Phase 4a — package scaffolding -- Completed.
|
||||||
|
|
||||||
|
`server.Server` owns the mux, the `local`/`peer` routers, `muxlog`, and a
|
||||||
|
shutdown context. `New` builds the routers, registers all model-dispatched
|
||||||
|
routes on a stdlib `http.ServeMux`, and wraps the mux with the global CORS
|
||||||
|
middleware. `localPeerHandler` resolves the model once via `router.FetchModel`
|
||||||
|
and dispatches to `local` or `peer`. `Shutdown` stops both routers in parallel
|
||||||
|
and is idempotent. `cmd/newrouter/main.go` now constructs `server.New(...)`;
|
||||||
|
`internal/router/server.go` and `server_test.go` were removed as dead code.
|
||||||
|
|
||||||
|
### Phase 4b — Custom HTTP endpoints -- Completed.
|
||||||
|
|
||||||
|
`GET /v1/models` (local + peer models, aliases, metadata), `GET /health`,
|
||||||
|
`GET /wol-health`, and `GET /` → `/ui` are registered. `GET /favicon.ico` is
|
||||||
|
deferred to Phase 7 since it requires the embedded UI filesystem.
|
||||||
|
|
||||||
|
### Phase 4c — Request-body filters -- Completed.
|
||||||
|
|
||||||
|
`CreateFilterMiddleware` (in `filters.go`) applies `UseModelName`,
|
||||||
|
`StripParams`, `SetParams`, and `SetParamsByID` to JSON requests, then
|
||||||
|
re-attaches the body with `Content-Length` / `Transfer-Encoding` cleanup.
|
||||||
|
|
||||||
|
### Phase 4d — Auth & CORS -- Completed.
|
||||||
|
|
||||||
|
`CreateAuthMiddleware` validates API keys (Bearer / Basic / `x-api-key`) and
|
||||||
|
strips the headers before upstream. `CreateCORSMiddleware` answers OPTIONS
|
||||||
|
preflight; `/v1/models` echoes the `Origin`.
|
||||||
|
|
||||||
|
### Phase 4e — Upstream passthrough -- Completed.
|
||||||
|
|
||||||
|
`GET /upstream` → `/ui/models`, and `/upstream/<model>/<path>` proxies to the
|
||||||
|
resolved model with multi-segment name resolution, canonical-form redirect
|
||||||
|
(301/308), and prefix stripping.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 5 — Operations endpoints -- Completed.
|
||||||
|
|
||||||
|
A new `router.LocalRouter` interface embeds `Router` and adds `RunningModels()`
|
||||||
|
and `Unload(timeout, models...)`, both implemented once on `baseRouter` so
|
||||||
|
`Group` and `Matrix` share them — the legacy matrix/group divergence at
|
||||||
|
[proxymanager.go:1167](../proxy/proxymanager.go#L1167) collapses since
|
||||||
|
`baseRouter` already unifies process storage. `Peer` does not implement it;
|
||||||
|
`Server.local` is typed `LocalRouter`, `Server.peer` stays `Router`.
|
||||||
|
|
||||||
|
`GET /unload` stops every local process; `GET /running` lists non-stopped
|
||||||
|
processes joined against config for `cmd`/`proxy`/`ttl`/`name`/`description`.
|
||||||
|
`startPreload` fires a background `GET /` at each `Hooks.OnStartup.Preload`
|
||||||
|
model and emits `shared.ModelPreloadedEvent`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 6 — Metrics, perf, and SSE -- Completed.
|
||||||
|
|
||||||
|
`perf.Monitor` is created and started in `cmd/newrouter/main.go` (it outlives
|
||||||
|
config reloads via `UpdateConfig`) and passed into `server.New`. `GET /metrics`
|
||||||
|
serves `perf.Monitor.MetricsHandler()` output, 503 when disabled.
|
||||||
|
|
||||||
|
`internal/process` emits `shared.ProcessStateChangeEvent` from `setState`.
|
||||||
|
`server.inflightCounter` (atomic) + `CreateInflightMiddleware` track
|
||||||
|
model-dispatched requests and emit `InFlightRequestsEvent`. `metricsMonitor`
|
||||||
|
(in `metrics.go`) parses token usage from upstream responses via
|
||||||
|
`CreateMetricsMiddleware`.
|
||||||
|
|
||||||
|
The `/api` group (API-key protected) is registered: `POST /api/models/unload`,
|
||||||
|
`POST /api/models/unload/{model...}`, `GET /api/events` (SSE: `modelStatus` /
|
||||||
|
`logData` / `metrics` / `inflight`), `GET /api/metrics`, `GET /api/performance`
|
||||||
|
(`?after=` RFC3339 filter), `GET /api/version`. `GET /api/captures/{id}`
|
||||||
|
returns 501 until 6f.
|
||||||
|
|
||||||
|
### Phase 6f — Request/response captures -- Completed.
|
||||||
|
|
||||||
|
`proxy/cache` moved to `internal/cache`. `metricsMonitor` stores zstd+CBOR
|
||||||
|
`ReqRespCapture` records in a sized `cache.Cache` (`captureBuffer` MB, 0
|
||||||
|
disables). `CreateMetricsMiddleware` buffers request body/headers before
|
||||||
|
dispatch; `record` builds the capture per a `captureFieldsByPath` table
|
||||||
|
(`captures.go`) that trims large audio/image payloads, defaulting JSON routes
|
||||||
|
to `captureAll`. `GET /api/captures/{id}` decompresses and returns the capture;
|
||||||
|
`getMetrics` resolves `HasCapture` against the cache.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 7 — UI serving -- Completed.
|
||||||
|
|
||||||
|
`internal/server/ui.go` embeds `ui_dist` and serves it. `GET /ui/` is
|
||||||
|
brotli/gzip-aware via `serveCompressedFile`; unknown paths without a file
|
||||||
|
extension fall back to `index.html` for SPA routing. `GET /favicon.ico` serves
|
||||||
|
from the same embedded FS. The Makefile `ui` target copies the vite build into
|
||||||
|
`internal/server/ui_dist`; a committed `placeholder.txt` keeps the embed valid
|
||||||
|
before a build runs.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 8a - Review Part I
|
||||||
|
|
||||||
|
- [x] All functionality from the proxy package has been migrated in the above phases — with the remaining gaps listed in Phase 8b
|
||||||
|
- [x] Test coverage meets or exceeds the level of the proxy package — `internal/server` now at 76.6% vs 73.9% (`proxy`)
|
||||||
|
|
||||||
|
### Findings
|
||||||
|
|
||||||
|
**Gap 1 — Request logging middleware missing -- Resolved.**
|
||||||
|
|
||||||
|
`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records one
|
||||||
|
access-log line per request to `s.proxylog` in the legacy format
|
||||||
|
`clientIP "METHOD PATH PROTO" status bodySize "UA" duration`, skipping
|
||||||
|
`/wol-health`, `/api/performance`, and `/metrics`. A `statusRecorder` captures
|
||||||
|
the status/body size (forwarding `Flush` for SSE) and `clientIP` honours
|
||||||
|
`X-Forwarded-For` / `X-Real-IP`. It is wired as the outermost middleware in
|
||||||
|
`routes()`, wrapping the CORS layer.
|
||||||
|
|
||||||
|
**Gap 2 — Per-model log streaming not supported -- Resolved.**
|
||||||
|
|
||||||
|
`Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) only handles `""`, `"proxy"`, and `"upstream"`. The legacy `ProxyManager.getLogger` ([proxymanager_loghandlers.go:92](../proxy/proxymanager_loghandlers.go#L92)) additionally resolves a model ID against the active process groups / matrix and returns that process's logger. Callers of `GET /logs/stream/<modelID>` will get a 400 instead of the model's live log stream.
|
||||||
|
|
||||||
|
**Gap 3 — `UseModelName` not applied to multipart form endpoints -- Resolved.**
|
||||||
|
|
||||||
|
`CreateFormFilterMiddleware` ([filters.go](../internal/server/filters.go)) parses
|
||||||
|
`multipart/form-data` requests, rewrites the `model` field with `UseModelName`,
|
||||||
|
reconstructs the body via `rewriteMultipartModel`, and re-attaches it with
|
||||||
|
`Content-Type` / `Content-Length` cleanup. It runs in `modelChain` after the
|
||||||
|
JSON `filterMW`; each is a no-op for the other's Content-Type. Audio
|
||||||
|
transcription (`/v1/audio/transcriptions`) and image edit (`/v1/images/edits`)
|
||||||
|
now honour `use_model_name`.
|
||||||
|
|
||||||
|
**Coverage gaps (0 % functions) -- Resolved.**
|
||||||
|
|
||||||
|
The functions previously at 0 % (`handleListModels`, `handleMetrics`,
|
||||||
|
`handleRootRedirect`, `handleUpstreamRedirect`, `handleUpstream`,
|
||||||
|
`findModelInPath`, `handleAPICapture`, `handleAPIUnloadAll`,
|
||||||
|
`handleAPIUnloadModel`, `CreateAuthMiddleware`, `extractAPIKey`,
|
||||||
|
`handleLogStream`, `applyFilters`, `decompressBody`, `filterAcceptEncoding`,
|
||||||
|
`handleUI`, `handleFavicon`) now have tests across `auth_test.go`, `api_test.go`,
|
||||||
|
`filters_test.go`, `log_test.go`, and `extras_test.go`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 8b - Fill gaps discovered in Phase 8a
|
||||||
|
|
||||||
|
- [x] **Add request-log middleware** — `CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records `clientIP "METHOD PATH PROTO" status bodySize "UA" duration` to `s.proxylog`, skips `/wol-health` / `/api/performance` / `/metrics`, and is wired as the outermost middleware in `routes()`.
|
||||||
|
- [x] **Extend `getLogger` with model-ID resolution** — add a `default:` branch to `Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) that resolves the ID via `s.local` (using a new `LocalRouter.GetProcess(name)` method or equivalent) and returns that process's `Logger()`. Match the fallback behaviour: return a 400 with `"invalid logger. Use 'proxy', 'upstream' or a model's ID"` when not found.
|
||||||
|
- [x] **`UseModelName` rewrite for multipart endpoints** — `CreateFormFilterMiddleware` parses `multipart/form-data`, rewrites the `model` field according to `UseModelName`, reconstructs the body, and updates `Content-Type` / `Content-Length`. It is wired into `modelChain` after the JSON filter.
|
||||||
|
- [x] **Raise test coverage to ≥ 74 %** — `internal/server` now at 76.1%; tests added for every 0 % function across `auth_test.go`, `api_test.go`, `filters_test.go`, `log_test.go`, and `extras_test.go`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 8c - Review Part II (entrypoint comparison)
|
||||||
|
|
||||||
|
A second pass comparing [cmd/newrouter/main.go](../cmd/newrouter/main.go) against
|
||||||
|
the legacy [llama-swap.go](../llama-swap.go) + [proxy.New](../proxy/proxymanager.go#L104)
|
||||||
|
surfaced four more gaps, all in logger setup.
|
||||||
|
|
||||||
|
**Gap 4 — `LogToStdout` config ignored -- Resolved.**
|
||||||
|
|
||||||
|
`cmd/newrouter/main.go` previously hardcoded `proxyLog` / `upstreamLog` to
|
||||||
|
`os.Stdout`, and the old `muxlog()` helper built a Monitor that nothing wrote
|
||||||
|
into — so `logToStdout` had no effect and `/logs` (combined history) was always
|
||||||
|
empty. `server.NewLoggers` ([log.go](../internal/server/log.go)) now replicates
|
||||||
|
the legacy switch: `proxy` / `upstream` monitors feed `muxLog` (or `io.Discard`)
|
||||||
|
per `none` / `both` / `upstream` / `proxy`, so `muxLog` accumulates the combined
|
||||||
|
history. `server.New` takes `muxlog` as a parameter. The loggers outlive config
|
||||||
|
reloads, so a `LogToStdout` change requires a restart to take effect.
|
||||||
|
|
||||||
|
**Gap 5 — `LogTimeFormat` config ignored -- Resolved.**
|
||||||
|
|
||||||
|
`cmd/newrouter/main.go` now maps `cfg.LogTimeFormat` to a Go time layout via the
|
||||||
|
`logTimeFormats` table and applies it (alongside log level) to the proxy and
|
||||||
|
upstream monitors in `applyLogSettings`, re-applied on config reload.
|
||||||
|
|
||||||
|
**Gap 6 — `LogRequests` deprecation warning missing.**
|
||||||
|
|
||||||
|
The legacy [proxymanager.go:127](../proxy/proxymanager.go#L127) warns when the
|
||||||
|
deprecated `logRequests` config key is set. `cmd/newrouter` does not. Low
|
||||||
|
priority — left open.
|
||||||
|
|
||||||
|
**Gap 7 — PID debug log missing -- Resolved.**
|
||||||
|
|
||||||
|
`cmd/newrouter/main.go` now logs `PID: %d` at debug level after `applyLogSettings`,
|
||||||
|
matching [llama-swap.go:71](../llama-swap.go#L71).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase X (tbd) — Cutover
|
||||||
|
|
||||||
|
- [ ] Swap `llama-swap.go` to delegate to `cmd/newrouter` (or rename newrouter to be the primary entrypoint)
|
||||||
|
- [ ] Update `Makefile` build targets
|
||||||
|
- [ ] Update docs / README references to the legacy binary
|
||||||
|
- [ ] Remove `proxy/proxymanager*.go` and `gin-gonic` dependency once nothing imports them
|
||||||
|
- [ ] Run `make test-all` and confirm concurrency suite still passes against the new entrypoint
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Cross-cutting concerns to keep in mind
|
||||||
|
|
||||||
|
- **Single body read**: legacy and newrouter both buffer the request body once. When adding filters (Phase 4c), make sure the buffered bytes flow through `Content-Length` / `transfer-encoding` cleanup as in [proxymanager.go:872](../proxy/proxymanager.go#L872).
|
||||||
|
- **Streaming flag in context**: legacy stashes `streaming` and `model` under `proxyCtxKey`. The new router uses `ModelKey` / `ModelIDKey` — pick one set of keys and use them consistently for metrics + log handlers.
|
||||||
|
- **Matrix vs Group divergence**: any handler that calls `swapProcessGroup` or `findGroupByModelName` in the legacy needs a matrix branch too. The new router's `Router` interface already abstracts this — preserve that abstraction rather than reintroducing the branch in every handler.
|
||||||
|
- **Shutdown ordering**: `httpServer.Shutdown` must drain inflight requests _before_ `Server.Shutdown` tears down processes, otherwise inflight requests 502. Current newrouter ordering at [main.go:87](../cmd/newrouter/main.go#L87) is correct — keep it.
|
||||||
@@ -4,6 +4,9 @@ go 1.26.1
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/billziss-gh/golib v0.2.0
|
github.com/billziss-gh/golib v0.2.0
|
||||||
|
github.com/charmbracelet/bubbles v1.0.0
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0
|
||||||
github.com/fxamacker/cbor/v2 v2.9.1
|
github.com/fxamacker/cbor/v2 v2.9.1
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/klauspost/compress v1.18.5
|
github.com/klauspost/compress v1.18.5
|
||||||
@@ -11,16 +14,27 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
|
golang.org/x/sync v0.20.0
|
||||||
|
golang.org/x/sys v0.41.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
|
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||||
|
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||||
|
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||||
|
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||||
|
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/ebitengine/purego v0.10.0 // indirect
|
github.com/ebitengine/purego v0.10.0 // indirect
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
@@ -31,13 +45,20 @@ require (
|
|||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||||
|
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||||
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||||
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||||
@@ -45,11 +66,11 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.45.0 // indirect
|
golang.org/x/crypto v0.45.0 // indirect
|
||||||
golang.org/x/net v0.47.0 // indirect
|
golang.org/x/net v0.47.0 // indirect
|
||||||
golang.org/x/sys v0.41.0 // indirect
|
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,31 @@
|
|||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||||
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||||
|
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||||
|
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||||
|
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||||
|
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||||
|
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||||
|
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||||
|
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||||
|
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||||
|
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||||
|
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||||
|
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||||
|
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||||
@@ -13,6 +35,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
@@ -47,21 +71,35 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY
|
|||||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||||
|
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||||
|
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||||
|
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||||
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
|
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||||
|
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||||
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY=
|
github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY=
|
||||||
github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
@@ -97,6 +135,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
|
|||||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
@@ -104,10 +144,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
|||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||||
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
|
|||||||
@@ -0,0 +1,63 @@
|
|||||||
|
// Package chain composes http.Handler middleware into a single handler.
|
||||||
|
//
|
||||||
|
// A Middleware wraps a downstream http.Handler and may run logic before or
|
||||||
|
// after delegating to it, or short-circuit by not calling next at all
|
||||||
|
// (e.g. auth failure, CORS preflight).
|
||||||
|
package chain
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Middleware decorates an http.Handler. Given the next handler in line, it
// returns a handler that can run logic around next, alter the request or
// response, or decline to call next at all (short-circuit).
type Middleware func(next http.Handler) http.Handler

// Chain holds an ordered list of Middleware that can be applied to any number
// of terminal handlers. Construct one with New, derive per-route variants with
// Append, and produce the final handler with Then or ThenFunc:
//
//	api := chain.New(authMW, corsMW)
//	mux.Handle("/v1/chat/completions", api.Then(dispatch))
//	mux.Handle("/v1/embeddings", api.Append(filters).Then(dispatch))
//
// Execution order follows declaration order: the first middleware is the
// outermost wrapper and the terminal handler is the innermost. If a
// middleware never calls next, nothing further down the chain runs.
// The zero Chain carries no middleware and is ready to use.
type Chain struct {
	mws []Middleware
}

// New builds a Chain from mws, preserving their order. The argument slice is
// copied, so later changes to it by the caller do not affect the Chain.
func New(mws ...Middleware) Chain {
	own := append(make([]Middleware, 0, len(mws)), mws...)
	return Chain{mws: own}
}

// Append returns a Chain consisting of the receiver's middleware followed by
// mws. A fresh backing slice is always allocated, so the receiver is left
// untouched and may be shared between routes with different additions.
func (c Chain) Append(mws ...Middleware) Chain {
	merged := make([]Middleware, len(c.mws)+len(mws))
	n := copy(merged, c.mws)
	copy(merged[n:], mws)
	return Chain{mws: merged}
}

// Then applies the chain's middleware to final and returns the composed
// handler, ready for http.ServeMux.Handle. An empty chain yields final
// itself.
func (c Chain) Then(final http.Handler) http.Handler {
	wrapped := final
	// Wrap from the last middleware inward-out so that the first one ends up
	// outermost and therefore runs first.
	last := len(c.mws) - 1
	for offset := range c.mws {
		wrap := c.mws[last-offset]
		wrapped = wrap(wrapped)
	}
	return wrapped
}

// ThenFunc behaves like Then for a terminal handler given as an
// http.HandlerFunc.
func (c Chain) ThenFunc(f http.HandlerFunc) http.Handler {
	var final http.Handler = f
	return c.Then(final)
}
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
package chain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recordingMiddleware appends tag before calling next and "-after-"+tag after.
|
||||||
|
func recordingMiddleware(tag string, log *[]string) Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
*log = append(*log, tag)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
*log = append(*log, "after-"+tag)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_HandlersExecuteInDeclaredOrder(t *testing.T) {
|
||||||
|
var log []string
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "final")
|
||||||
|
})
|
||||||
|
|
||||||
|
h := New(
|
||||||
|
recordingMiddleware("a", &log),
|
||||||
|
recordingMiddleware("b", &log),
|
||||||
|
recordingMiddleware("c", &log),
|
||||||
|
).Then(final)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
want := []string{"a", "b", "c", "final", "after-c", "after-b", "after-a"}
|
||||||
|
if !equal(log, want) {
|
||||||
|
t.Fatalf("execution order mismatch:\n got: %v\nwant: %v", log, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_ShortCircuitsWhenMiddlewareDoesNotCallNext(t *testing.T) {
|
||||||
|
var log []string
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "final")
|
||||||
|
})
|
||||||
|
|
||||||
|
gate := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "gate")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
h := New(
|
||||||
|
recordingMiddleware("outer", &log),
|
||||||
|
gate,
|
||||||
|
recordingMiddleware("inner", &log),
|
||||||
|
).Then(final)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
want := []string{"outer", "gate", "after-outer"}
|
||||||
|
if !equal(log, want) {
|
||||||
|
t.Fatalf("short-circuit order mismatch:\n got: %v\nwant: %v", log, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_EarlyWritesAreVisibleToLaterMiddleware(t *testing.T) {
|
||||||
|
header := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Set-By", "outer")
|
||||||
|
_, _ = io.WriteString(w, "outer:")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
inner := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// The outer middleware already set the header; we should see it.
|
||||||
|
if got := w.Header().Get("X-Set-By"); got != "outer" {
|
||||||
|
_, _ = io.WriteString(w, "missing-header;")
|
||||||
|
}
|
||||||
|
_, _ = io.WriteString(w, "inner:")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "final")
|
||||||
|
})
|
||||||
|
|
||||||
|
h := New(header, inner).Then(final)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(rec.Body)
|
||||||
|
if got := string(body); !strings.Contains(got, "outer:inner:final") {
|
||||||
|
t.Fatalf("body: got %q, want it to contain %q", got, "outer:inner:final")
|
||||||
|
}
|
||||||
|
if got := rec.Header().Get("X-Set-By"); got != "outer" {
|
||||||
|
t.Fatalf("header X-Set-By: got %q, want %q", got, "outer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_ReusableAcrossRoutesViaThen(t *testing.T) {
|
||||||
|
var log []string
|
||||||
|
base := New(
|
||||||
|
recordingMiddleware("auth", &log),
|
||||||
|
recordingMiddleware("cors", &log),
|
||||||
|
)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle("/a", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "handler-a")
|
||||||
|
}))
|
||||||
|
mux.Handle("/b", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "handler-b")
|
||||||
|
}))
|
||||||
|
|
||||||
|
srv := httptest.NewServer(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
for _, path := range []string{"/a", "/b"} {
|
||||||
|
resp, err := http.Get(srv.URL + path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET %s: %v", path, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{
|
||||||
|
"auth", "cors", "handler-a", "after-cors", "after-auth",
|
||||||
|
"auth", "cors", "handler-b", "after-cors", "after-auth",
|
||||||
|
}
|
||||||
|
if !equal(log, want) {
|
||||||
|
t.Fatalf("reusable chain order mismatch:\n got: %v\nwant: %v", log, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_AppendDoesNotMutateReceiver(t *testing.T) {
|
||||||
|
var log []string
|
||||||
|
base := New(recordingMiddleware("base", &log))
|
||||||
|
extended := base.Append(recordingMiddleware("extra", &log))
|
||||||
|
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "final")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Run extended first to surface any aliasing of the underlying slice.
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
extended.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
base.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
want := []string{
|
||||||
|
"base", "extra", "final", "after-extra", "after-base",
|
||||||
|
"base", "final", "after-base",
|
||||||
|
}
|
||||||
|
if !equal(log, want) {
|
||||||
|
t.Fatalf("Append must not mutate the receiver:\n got: %v\nwant: %v", log, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_ZeroValueAndEmptyThenAreIdentity(t *testing.T) {
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusTeapot)
|
||||||
|
})
|
||||||
|
|
||||||
|
for name, c := range map[string]Chain{
|
||||||
|
"zero": {},
|
||||||
|
"empty": New(),
|
||||||
|
} {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := c.Then(final)
|
||||||
|
if _, ok := h.(http.HandlerFunc); !ok {
|
||||||
|
t.Fatalf("expected http.HandlerFunc identity, got %T", h)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
if rec.Code != http.StatusTeapot {
|
||||||
|
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusTeapot)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal reports whether a and b have the same length and hold the same
// strings at every index. Nil and empty slices compare equal.
func equal(a, b []string) bool {
	differs := len(a) != len(b)
	for i := 0; !differs && i < len(a); i++ {
		differs = a[i] != b[i]
	}
	return !differs
}
|
||||||
@@ -272,6 +272,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
nextPort := config.StartPort
|
nextPort := config.StartPort
|
||||||
for _, modelId := range modelIds {
|
for _, modelId := range modelIds {
|
||||||
modelConfig := config.Models[modelId]
|
modelConfig := config.Models[modelId]
|
||||||
|
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||||
|
|
||||||
// Strip comments from command fields
|
// Strip comments from command fields
|
||||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
@@ -189,42 +189,46 @@ groups:
|
|||||||
SendLoadingState: false,
|
SendLoadingState: false,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
Proxy: "http://localhost:8080",
|
Proxy: "http://localhost:8080",
|
||||||
Aliases: []string{"m1", "model-one"},
|
Aliases: []string{"m1", "model-one"},
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
Name: "Model 1",
|
Name: "Model 1",
|
||||||
Description: "This is model 1",
|
Description: "This is model 1",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
"model2": {
|
"model2": {
|
||||||
Cmd: "path/to/server --arg1 one",
|
Cmd: "path/to/server --arg1 one",
|
||||||
Proxy: "http://localhost:8081",
|
Proxy: "http://localhost:8081",
|
||||||
Aliases: []string{"m2"},
|
Aliases: []string{"m2"},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
"model3": {
|
"model3": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
Proxy: "http://localhost:8081",
|
Proxy: "http://localhost:8081",
|
||||||
Aliases: []string{"mthree"},
|
Aliases: []string{"mthree"},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
"model4": {
|
"model4": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
Proxy: "http://localhost:8082",
|
Proxy: "http://localhost:8082",
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
Aliases: []string{},
|
Aliases: []string{},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
@@ -176,44 +176,48 @@ groups:
|
|||||||
SendLoadingState: false,
|
SendLoadingState: false,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
Proxy: "http://localhost:8080",
|
Proxy: "http://localhost:8080",
|
||||||
Aliases: []string{"m1", "model-one"},
|
Aliases: []string{"m1", "model-one"},
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
"model2": {
|
"model2": {
|
||||||
Cmd: "path/to/server --arg1 one",
|
Cmd: "path/to/server --arg1 one",
|
||||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
Proxy: "http://localhost:8081",
|
Proxy: "http://localhost:8081",
|
||||||
Aliases: []string{"m2"},
|
Aliases: []string{"m2"},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
"model3": {
|
"model3": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
Proxy: "http://localhost:8081",
|
Proxy: "http://localhost:8081",
|
||||||
Aliases: []string{"mthree"},
|
Aliases: []string{"mthree"},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
"model4": {
|
"model4": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
Proxy: "http://localhost:8082",
|
Proxy: "http://localhost:8082",
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
Aliases: []string{},
|
Aliases: []string{},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
Timeouts: defaultTimeout,
|
Timeouts: defaultTimeout,
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
@@ -54,6 +54,9 @@ type ModelConfig struct {
|
|||||||
|
|
||||||
// Timeout settings for proxy connections
|
// Timeout settings for proxy connections
|
||||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||||
|
|
||||||
|
// Copy of HealthCheckTimeout from global config
|
||||||
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DataEventID = 0x04
|
const DataEventID = 0x04
|
||||||
|
|||||||
@@ -0,0 +1,214 @@
|
|||||||
|
package perf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseNvidiaSmiLine parses a single line from nvidia-smi CSV output.
|
||||||
|
// Format: index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw
|
||||||
|
func ParseNvidiaSmiLine(line string) *GpuStat {
|
||||||
|
fields := strings.Split(line, ",")
|
||||||
|
if len(fields) < 9 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id, _ := strconv.Atoi(strings.TrimSpace(fields[0]))
|
||||||
|
name := strings.TrimSpace(fields[1])
|
||||||
|
uuid := strings.TrimSpace(fields[2])
|
||||||
|
tempC, _ := strconv.Atoi(strings.TrimSpace(fields[3]))
|
||||||
|
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[4]), 64)
|
||||||
|
memUsed, _ := strconv.Atoi(strings.TrimSpace(fields[5]))
|
||||||
|
memTotal, _ := strconv.Atoi(strings.TrimSpace(fields[6]))
|
||||||
|
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[7]), 64)
|
||||||
|
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
|
||||||
|
|
||||||
|
var memUtil float64
|
||||||
|
if memTotal > 0 {
|
||||||
|
memUtil = float64(memUsed) / float64(memTotal) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GpuStat{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
UUID: uuid,
|
||||||
|
TempC: tempC,
|
||||||
|
GpuUtilPct: gpuUtil,
|
||||||
|
MemUtilPct: memUtil,
|
||||||
|
MemUsedMB: memUsed,
|
||||||
|
MemTotalMB: memTotal,
|
||||||
|
FanSpeedPct: fanSpeed,
|
||||||
|
PowerDrawW: powerDraw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mactopOutput maps the subset of mactop's headless JSON output that is
|
||||||
|
// relevant to GpuStat. Note that mactop's memory object is whole-system memory,
|
||||||
|
// not GPU-attributed; the darwin monitor overlays ioreg's GPU-attributed
|
||||||
|
// unified memory (see overlayIoregMem) so both backends report consistent
|
||||||
|
// memory figures.
|
||||||
|
type mactopOutput struct {
|
||||||
|
SocMetrics struct {
|
||||||
|
GPUPower float64 `json:"gpu_power"`
|
||||||
|
GPUFreq int `json:"gpu_freq_mhz"`
|
||||||
|
GPUTemp float64 `json:"gpu_temp"`
|
||||||
|
} `json:"soc_metrics"`
|
||||||
|
Memory struct {
|
||||||
|
Total uint64 `json:"total"`
|
||||||
|
Used uint64 `json:"used"`
|
||||||
|
} `json:"memory"`
|
||||||
|
GPUUsage float64 `json:"gpu_usage"`
|
||||||
|
SystemInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
GPUCoreCount int `json:"gpu_core_count"`
|
||||||
|
} `json:"system_info"`
|
||||||
|
Fans []struct {
|
||||||
|
RPM int `json:"rpm"`
|
||||||
|
MinRPM int `json:"min_rpm"`
|
||||||
|
MaxRPM int `json:"max_rpm"`
|
||||||
|
} `json:"fans"`
|
||||||
|
Temperatures []struct {
|
||||||
|
Group string `json:"group"`
|
||||||
|
Avg float64 `json:"avg_celsius"`
|
||||||
|
} `json:"temperatures"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioreg output uses ` = ` (with spaces) for top-level device properties and
|
||||||
|
// `=` (no spaces) for values inside nested dictionaries such as
|
||||||
|
// PerformanceStatistics.
|
||||||
|
var (
|
||||||
|
reIoregModel = regexp.MustCompile(`"model"\s*=\s*"([^"]+)"`)
|
||||||
|
reIoregCoreCount = regexp.MustCompile(`"gpu-core-count"\s*=\s*(\d+)`)
|
||||||
|
reIoregUtil = regexp.MustCompile(`"Device Utilization %"=(\d+)`)
|
||||||
|
reIoregMemUsed = regexp.MustCompile(`"In use system memory"=(\d+)`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseIoregOutput parses `ioreg -r -c IOGPU -d 1 -f` output into a GpuStat for
|
||||||
|
// the Apple Silicon integrated GPU. This is a fallback for when mactop is not
|
||||||
|
// installed: utilization and used memory are available, but power, temperature,
|
||||||
|
// and fan speed are not exposed by ioreg. memTotalMB is the unified memory size
|
||||||
|
// supplied by the caller, since Apple Silicon shares memory between CPU and GPU.
|
||||||
|
// Returns nil if no GPU device is found in the output.
|
||||||
|
func ParseIoregOutput(out []byte, memTotalMB int) *GpuStat {
|
||||||
|
utilMatch := reIoregUtil.FindSubmatch(out)
|
||||||
|
memMatch := reIoregMemUsed.FindSubmatch(out)
|
||||||
|
if utilMatch == nil && memMatch == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gpuUtil float64
|
||||||
|
if utilMatch != nil {
|
||||||
|
gpuUtil, _ = strconv.ParseFloat(string(utilMatch[1]), 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
const toMB = 1024 * 1024
|
||||||
|
var memUsedMB int
|
||||||
|
if memMatch != nil {
|
||||||
|
memUsedBytes, _ := strconv.ParseInt(string(memMatch[1]), 10, 64)
|
||||||
|
memUsedMB = int(memUsedBytes / toMB)
|
||||||
|
}
|
||||||
|
|
||||||
|
var memUtil float64
|
||||||
|
if memTotalMB > 0 {
|
||||||
|
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
name := "Apple GPU"
|
||||||
|
if m := reIoregModel.FindSubmatch(out); m != nil {
|
||||||
|
name = string(m[1])
|
||||||
|
}
|
||||||
|
if m := reIoregCoreCount.FindSubmatch(out); m != nil {
|
||||||
|
if cores, err := strconv.Atoi(string(m[1])); err == nil && cores > 0 {
|
||||||
|
name = fmt.Sprintf("%s (%d-core GPU)", name, cores)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GpuStat{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
ID: 0,
|
||||||
|
Name: name,
|
||||||
|
GpuUtilPct: gpuUtil,
|
||||||
|
MemUtilPct: memUtil,
|
||||||
|
MemUsedMB: memUsedMB,
|
||||||
|
MemTotalMB: memTotalMB,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseMactopLine parses a single line of mactop headless JSON output into a
|
||||||
|
// GpuStat for the Apple Silicon integrated GPU. Returns nil if the line cannot
|
||||||
|
// be parsed.
|
||||||
|
func ParseMactopLine(line string) *GpuStat {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var out mactopOutput
|
||||||
|
if err := json.Unmarshal([]byte(line), &out); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const toMB = 1024 * 1024
|
||||||
|
memUsedMB := int(out.Memory.Used / toMB)
|
||||||
|
memTotalMB := int(out.Memory.Total / toMB)
|
||||||
|
|
||||||
|
var memUtil float64
|
||||||
|
if memTotalMB > 0 {
|
||||||
|
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
name := out.SystemInfo.Name
|
||||||
|
if name == "" {
|
||||||
|
name = "Apple GPU"
|
||||||
|
}
|
||||||
|
if out.SystemInfo.GPUCoreCount > 0 {
|
||||||
|
name = fmt.Sprintf("%s (%d-core GPU)", name, out.SystemInfo.GPUCoreCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unified memory has no dedicated VRAM sensor; use the memory temperature
|
||||||
|
// group when mactop exposes it.
|
||||||
|
var vramTempC int
|
||||||
|
for _, t := range out.Temperatures {
|
||||||
|
if strings.EqualFold(t.Group, "Memory") {
|
||||||
|
vramTempC = int(math.Round(t.Avg))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Average fan load across all fans as a percentage of their RPM range.
|
||||||
|
var fanSpeed float64
|
||||||
|
var fanCount int
|
||||||
|
for _, f := range out.Fans {
|
||||||
|
if f.MaxRPM > f.MinRPM {
|
||||||
|
pct := float64(f.RPM-f.MinRPM) / float64(f.MaxRPM-f.MinRPM) * 100
|
||||||
|
if pct < 0 {
|
||||||
|
pct = 0
|
||||||
|
}
|
||||||
|
fanSpeed += pct
|
||||||
|
fanCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fanCount > 0 {
|
||||||
|
fanSpeed /= float64(fanCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GpuStat{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
ID: 0,
|
||||||
|
Name: name,
|
||||||
|
TempC: int(math.Round(out.SocMetrics.GPUTemp)),
|
||||||
|
VramTempC: vramTempC,
|
||||||
|
GpuUtilPct: out.GPUUsage,
|
||||||
|
MemUtilPct: memUtil,
|
||||||
|
MemUsedMB: memUsedMB,
|
||||||
|
MemTotalMB: memTotalMB,
|
||||||
|
FanSpeedPct: fanSpeed,
|
||||||
|
PowerDrawW: out.SocMetrics.GPUPower,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNotImplemented = errors.New("Not Implemented")
|
ErrNotImplemented = errors.New("not implemented")
|
||||||
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
|
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package perf
|
package perf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
@@ -11,7 +15,156 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
return nil, ErrNotImplemented
|
if ch, err := tryMactop(ctx, every, logger); err == nil {
|
||||||
|
logger.Info("using mactop for GPU monitoring")
|
||||||
|
return ch, nil
|
||||||
|
} else {
|
||||||
|
logger.Debugf("mactop: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch, err := tryIoreg(ctx, every, logger); err == nil {
|
||||||
|
logger.Info("using ioreg for GPU monitoring")
|
||||||
|
return ch, nil
|
||||||
|
} else {
|
||||||
|
logger.Debugf("ioreg: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrNoGpuTool
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryIoreg polls `ioreg -r -c IOGPU -d 1 -f` for Apple Silicon GPU stats. It is
|
||||||
|
// a fallback for when mactop is not installed. ioreg exposes GPU utilization and
|
||||||
|
// used memory but not power, temperature, or fan speed.
|
||||||
|
func tryIoreg(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
|
if _, err := exec.LookPath("ioreg"); err != nil {
|
||||||
|
return nil, ErrNoGpuTool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify ioreg actually reports a GPU device before committing to it, so we
|
||||||
|
// can fall through to ErrNoGpuTool otherwise.
|
||||||
|
if stat := sampleIoreg(ctx); stat == nil {
|
||||||
|
return nil, fmt.Errorf("ioreg reported no GPU device")
|
||||||
|
}
|
||||||
|
|
||||||
|
if every < time.Second {
|
||||||
|
every = time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan []GpuStat, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
ticker := time.NewTicker(every)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
stat := sampleIoreg(ctx)
|
||||||
|
if stat == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case ch <- []GpuStat{*stat}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sampleIoreg runs ioreg once and parses a single GpuStat, or returns nil.
|
||||||
|
func sampleIoreg(ctx context.Context) *GpuStat {
|
||||||
|
out, err := exec.CommandContext(ctx, "ioreg", "-r", "-c", "IOGPU", "-d", "1", "-f").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var memTotalMB int
|
||||||
|
if vmStat, err := mem.VirtualMemory(); err == nil {
|
||||||
|
memTotalMB = int(vmStat.Total / (1024 * 1024))
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParseIoregOutput(out, memTotalMB)
|
||||||
|
}
|
||||||
|
|
||||||
|
// overlayIoregMem replaces a GpuStat's memory fields with the GPU-attributed
|
||||||
|
// unified memory reported by ioreg. mactop only exposes whole-system memory, so
|
||||||
|
// without this the mactop and ioreg backends would report different memory
|
||||||
|
// semantics. It is a no-op when ioreg is unavailable or reports no GPU memory,
|
||||||
|
// leaving the mactop-supplied values in place.
|
||||||
|
func overlayIoregMem(ctx context.Context, stat *GpuStat) {
|
||||||
|
ioStat := sampleIoreg(ctx)
|
||||||
|
if ioStat == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stat.MemUsedMB = ioStat.MemUsedMB
|
||||||
|
stat.MemTotalMB = ioStat.MemTotalMB
|
||||||
|
stat.MemUtilPct = ioStat.MemUtilPct
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryMactop streams Apple Silicon GPU stats from mactop's headless mode.
|
||||||
|
// See https://github.com/metaspartan/mactop. mactop emits one JSON object per
|
||||||
|
// sample to stdout, which we parse into GpuStat.
|
||||||
|
func tryMactop(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
|
if _, err := exec.LookPath("mactop"); err != nil {
|
||||||
|
return nil, ErrNoGpuTool
|
||||||
|
}
|
||||||
|
|
||||||
|
// mactop samples power over the interval, so give it at least a second.
|
||||||
|
intervalMs := int(every.Milliseconds())
|
||||||
|
if intervalMs < 1000 {
|
||||||
|
intervalMs = 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "mactop",
|
||||||
|
"--headless",
|
||||||
|
"--format", "json",
|
||||||
|
"--interval", fmt.Sprintf("%d", intervalMs),
|
||||||
|
)
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("mactop stdout pipe failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return nil, fmt.Errorf("mactop start failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan []GpuStat, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(stdout)
|
||||||
|
// mactop's JSON objects can be large; allow generous line lengths.
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
stat := ParseMactopLine(line)
|
||||||
|
if stat != nil {
|
||||||
|
// mactop only reports whole-system memory; overlay ioreg's
|
||||||
|
// GPU-attributed unified memory so both backends are consistent.
|
||||||
|
overlayIoregMem(ctx, stat)
|
||||||
|
select {
|
||||||
|
case ch <- []GpuStat{*stat}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmd.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readSysStats() (SysStat, error) {
|
func readSysStats() (SysStat, error) {
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -224,3 +224,90 @@ func TestCurrent_ConcurrentAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseNvidiaSmiLine_ValidLine(t *testing.T) {
|
||||||
|
line := "0, NVIDIA GeForce RTX 3080, GPU-12345678-1234-1234-1234-123456789abc, 65, 80, 8192, 10240, 75, 250"
|
||||||
|
|
||||||
|
stat := ParseNvidiaSmiLine(line)
|
||||||
|
require.NotNil(t, stat)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, stat.ID)
|
||||||
|
assert.Equal(t, "NVIDIA GeForce RTX 3080", stat.Name)
|
||||||
|
assert.Equal(t, "GPU-12345678-1234-1234-1234-123456789abc", stat.UUID)
|
||||||
|
assert.Equal(t, 65, stat.TempC)
|
||||||
|
assert.Equal(t, 80.0, stat.GpuUtilPct)
|
||||||
|
assert.Equal(t, 8192, stat.MemUsedMB)
|
||||||
|
assert.Equal(t, 10240, stat.MemTotalMB)
|
||||||
|
assert.Equal(t, 75.0, stat.FanSpeedPct)
|
||||||
|
assert.Equal(t, 250.0, stat.PowerDrawW)
|
||||||
|
assert.InDelta(t, 80.0, stat.MemUtilPct, 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNvidiaSmiLine_ShortLine(t *testing.T) {
|
||||||
|
line := "0, NVIDIA GPU, GPU-123"
|
||||||
|
|
||||||
|
stat := ParseNvidiaSmiLine(line)
|
||||||
|
assert.Nil(t, stat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNvidiaSmiLine_MissingFields(t *testing.T) {
|
||||||
|
line := "0, NVIDIA GPU, GPU-123, 65, 80, 8192, 10240, 75"
|
||||||
|
|
||||||
|
stat := ParseNvidiaSmiLine(line)
|
||||||
|
assert.Nil(t, stat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNvidiaSmiLine_ZeroMemoryTotal(t *testing.T) {
|
||||||
|
line := "0, NVIDIA GPU, GPU-123, 65, 80, 0, 0, 75, 250"
|
||||||
|
|
||||||
|
stat := ParseNvidiaSmiLine(line)
|
||||||
|
require.NotNil(t, stat)
|
||||||
|
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||||
|
}
|
||||||
|
|
||||||
|
const ioregSample = `+-o AGXAcceleratorG13X <class AGXAcceleratorG13X, id 0x1000009a1, registered, matched, active, busy 0 (39191 ms), retain 108>
|
||||||
|
{
|
||||||
|
"model" = "Apple M1 Pro"
|
||||||
|
"gpu-core-count" = 16
|
||||||
|
"PerformanceStatistics" = {"In use system memory (driver)"=0,"Alloc system memory"=14511046656,"Tiler Utilization %"=34,"recoveryCount"=0,"Renderer Utilization %"=34,"Device Utilization %"=34,"In use system memory"=7688503296}
|
||||||
|
"IOClass" = "AGXAcceleratorG13X"
|
||||||
|
}`
|
||||||
|
|
||||||
|
func TestParseIoregOutput_ValidOutput(t *testing.T) {
|
||||||
|
const memTotalMB = 32768
|
||||||
|
|
||||||
|
stat := ParseIoregOutput([]byte(ioregSample), memTotalMB)
|
||||||
|
require.NotNil(t, stat)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, stat.ID)
|
||||||
|
assert.Equal(t, "Apple M1 Pro (16-core GPU)", stat.Name)
|
||||||
|
assert.Equal(t, 34.0, stat.GpuUtilPct)
|
||||||
|
assert.Equal(t, 7688503296/(1024*1024), stat.MemUsedMB)
|
||||||
|
assert.Equal(t, memTotalMB, stat.MemTotalMB)
|
||||||
|
assert.InDelta(t, float64(stat.MemUsedMB)/memTotalMB*100, stat.MemUtilPct, 0.01)
|
||||||
|
// Not exposed by ioreg.
|
||||||
|
assert.Equal(t, 0, stat.TempC)
|
||||||
|
assert.Equal(t, 0.0, stat.PowerDrawW)
|
||||||
|
assert.Equal(t, 0.0, stat.FanSpeedPct)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseIoregOutput_NoGpuDevice(t *testing.T) {
|
||||||
|
stat := ParseIoregOutput([]byte("no gpu here"), 32768)
|
||||||
|
assert.Nil(t, stat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseIoregOutput_ZeroMemTotal(t *testing.T) {
|
||||||
|
stat := ParseIoregOutput([]byte(ioregSample), 0)
|
||||||
|
require.NotNil(t, stat)
|
||||||
|
assert.Equal(t, 0.0, stat.MemUtilPct)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseIoregOutput_MissingModel(t *testing.T) {
|
||||||
|
const out = `"Device Utilization %"=50,"In use system memory"=1048576`
|
||||||
|
|
||||||
|
stat := ParseIoregOutput([]byte(out), 1024)
|
||||||
|
require.NotNil(t, stat)
|
||||||
|
assert.Equal(t, "Apple GPU", stat.Name)
|
||||||
|
assert.Equal(t, 50.0, stat.GpuUtilPct)
|
||||||
|
assert.Equal(t, 1, stat.MemUsedMB)
|
||||||
|
}
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monit
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stat := parseNvidiaSmiLine(line)
|
stat := ParseNvidiaSmiLine(line)
|
||||||
if stat != nil {
|
if stat != nil {
|
||||||
select {
|
select {
|
||||||
case ch <- []GpuStat{*stat}:
|
case ch <- []GpuStat{*stat}:
|
||||||
@@ -184,42 +184,6 @@ func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monit
|
|||||||
return ch, nil
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNvidiaSmiLine(line string) *GpuStat {
|
|
||||||
fields := strings.Split(line, ", ")
|
|
||||||
if len(fields) < 9 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
id, _ := strconv.Atoi(strings.TrimSpace(fields[0]))
|
|
||||||
name := strings.TrimSpace(fields[1])
|
|
||||||
uuid := strings.TrimSpace(fields[2])
|
|
||||||
tempC, _ := strconv.Atoi(strings.TrimSpace(fields[3]))
|
|
||||||
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[4]), 64)
|
|
||||||
memUsed, _ := strconv.Atoi(strings.TrimSpace(fields[5]))
|
|
||||||
memTotal, _ := strconv.Atoi(strings.TrimSpace(fields[6]))
|
|
||||||
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[7]), 64)
|
|
||||||
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
|
|
||||||
|
|
||||||
var memUtil float64
|
|
||||||
if memTotal > 0 {
|
|
||||||
memUtil = float64(memUsed) / float64(memTotal) * 100
|
|
||||||
}
|
|
||||||
|
|
||||||
return &GpuStat{
|
|
||||||
Timestamp: time.Now(),
|
|
||||||
ID: id,
|
|
||||||
Name: name,
|
|
||||||
UUID: uuid,
|
|
||||||
TempC: tempC,
|
|
||||||
GpuUtilPct: gpuUtil,
|
|
||||||
MemUtilPct: memUtil,
|
|
||||||
MemUsedMB: memUsed,
|
|
||||||
MemTotalMB: memTotal,
|
|
||||||
FanSpeedPct: fanSpeed,
|
|
||||||
PowerDrawW: powerDraw,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
if _, err := exec.LookPath("rocm-smi"); err != nil {
|
if _, err := exec.LookPath("rocm-smi"); err != nil {
|
||||||
return nil, ErrNoGpuTool
|
return nil, ErrNoGpuTool
|
||||||
@@ -255,13 +219,18 @@ func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor
|
|||||||
|
|
||||||
stats := make([]GpuStat, 0)
|
stats := make([]GpuStat, 0)
|
||||||
scanner := bufio.NewScanner(strings.NewReader(string(out)))
|
scanner := bufio.NewScanner(strings.NewReader(string(out)))
|
||||||
|
var header string
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := strings.TrimSpace(scanner.Text())
|
line := strings.TrimSpace(scanner.Text())
|
||||||
if line == "" || strings.HasPrefix(line, "device,") {
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "device,") {
|
||||||
|
header = line
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stat := parseRocmSmiLine(line)
|
stat := parseRocmSmiLine(header, line)
|
||||||
if stat != nil {
|
if stat != nil {
|
||||||
stats = append(stats, *stat)
|
stats = append(stats, *stat)
|
||||||
}
|
}
|
||||||
@@ -280,51 +249,88 @@ func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor
|
|||||||
return ch, nil
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRocmSmiLine(line string) *GpuStat {
|
func parseRocmSmiLine(header string, line string) *GpuStat {
|
||||||
|
if header == "" || line == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
labels := strings.Split(header, ",")
|
||||||
fields := strings.Split(line, ",")
|
fields := strings.Split(line, ",")
|
||||||
if len(fields) < 20 {
|
if len(labels) != len(fields) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
device := strings.TrimSpace(fields[0])
|
result := &GpuStat{
|
||||||
id, err := strconv.Atoi(strings.TrimPrefix(device, "card"))
|
Timestamp: time.Now(),
|
||||||
if err != nil {
|
ID: -1,
|
||||||
return nil
|
|
||||||
}
|
|
||||||
deviceName := strings.TrimSpace(fields[1])
|
|
||||||
uuid := strings.TrimSpace(fields[5])
|
|
||||||
tempC, _ := strconv.ParseFloat(strings.TrimSpace(fields[6]), 64)
|
|
||||||
vramTempC, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
|
|
||||||
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[10]), 64)
|
|
||||||
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[12]), 64)
|
|
||||||
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[13]), 64)
|
|
||||||
memUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[14]), 64)
|
|
||||||
memTotal, _ := strconv.ParseUint(strings.TrimSpace(fields[17]), 10, 64)
|
|
||||||
memUsed, _ := strconv.ParseUint(strings.TrimSpace(fields[18]), 10, 64)
|
|
||||||
cardSeries := strings.TrimSpace(fields[19])
|
|
||||||
name := device
|
|
||||||
if cardSeries != "" && cardSeries != "N/A" {
|
|
||||||
name = cardSeries + " " + device
|
|
||||||
} else if deviceName != "" && deviceName != "N/A" {
|
|
||||||
name = deviceName + " " + device
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var device string
|
||||||
|
var deviceName string
|
||||||
|
var cardSeries string
|
||||||
|
var gfxVersion string
|
||||||
|
|
||||||
const toMB = 1024 * 1024
|
const toMB = 1024 * 1024
|
||||||
|
|
||||||
return &GpuStat{
|
for i, col := range labels {
|
||||||
Timestamp: time.Now(),
|
val := strings.TrimSpace(fields[i])
|
||||||
ID: id,
|
switch col {
|
||||||
Name: name,
|
case "device":
|
||||||
UUID: uuid,
|
device = val
|
||||||
TempC: int(tempC),
|
id, err := strconv.Atoi(strings.TrimPrefix(val, "card"))
|
||||||
VramTempC: int(vramTempC),
|
if err != nil {
|
||||||
GpuUtilPct: gpuUtil,
|
return nil
|
||||||
MemUtilPct: memUtil,
|
}
|
||||||
MemUsedMB: int(memUsed / toMB),
|
result.ID = id
|
||||||
MemTotalMB: int(memTotal / toMB),
|
case "Device Name":
|
||||||
FanSpeedPct: fanSpeed,
|
deviceName = val
|
||||||
PowerDrawW: powerDraw,
|
case "GUID":
|
||||||
|
result.UUID = val
|
||||||
|
case "Temperature (Sensor edge) (C)":
|
||||||
|
tempC, _ := strconv.ParseFloat(val, 64)
|
||||||
|
result.TempC = int(tempC)
|
||||||
|
case "Temperature (Sensor memory) (C)":
|
||||||
|
vramTempC, _ := strconv.ParseFloat(val, 64)
|
||||||
|
result.VramTempC = int(vramTempC)
|
||||||
|
case "Fan speed (%)":
|
||||||
|
fanSpeed, _ := strconv.ParseFloat(val, 64)
|
||||||
|
result.FanSpeedPct = fanSpeed
|
||||||
|
case "Current Socket Graphics Package Power (W)":
|
||||||
|
fallthrough
|
||||||
|
case "Average Graphics Package Power (W)":
|
||||||
|
powerDraw, _ := strconv.ParseFloat(val, 64)
|
||||||
|
result.PowerDrawW = powerDraw
|
||||||
|
case "GPU use (%)":
|
||||||
|
gpuUtil, _ := strconv.ParseFloat(val, 64)
|
||||||
|
result.GpuUtilPct = gpuUtil
|
||||||
|
case "GPU Memory Allocated (VRAM%)":
|
||||||
|
memUtil, _ := strconv.ParseFloat(val, 64)
|
||||||
|
result.MemUtilPct = memUtil
|
||||||
|
case "VRAM Total Memory (B)":
|
||||||
|
memTotal, _ := strconv.ParseUint(val, 10, 64)
|
||||||
|
result.MemTotalMB = int(memTotal / toMB)
|
||||||
|
case "VRAM Total Used Memory (B)":
|
||||||
|
memUsed, _ := strconv.ParseUint(val, 10, 64)
|
||||||
|
result.MemUsedMB = int(memUsed / toMB)
|
||||||
|
case "Card Series":
|
||||||
|
cardSeries = val
|
||||||
|
case "GFX Version":
|
||||||
|
gfxVersion = val
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.ID == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := device
|
||||||
|
if cardSeries != "" && cardSeries != "N/A" {
|
||||||
|
name = cardSeries + " " + device + " (" + gfxVersion + ")"
|
||||||
|
} else if deviceName != "" && deviceName != "N/A" {
|
||||||
|
name = deviceName + " " + device + " (" + gfxVersion + ")"
|
||||||
|
}
|
||||||
|
result.Name = name
|
||||||
|
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func trySysfs(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
func trySysfs(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package perf
|
package perf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
@@ -11,7 +15,68 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
return nil, ErrNotImplemented
|
if ch, err := tryNvidiaSmiWindows(ctx, every, logger); err == nil {
|
||||||
|
logger.Info("using nvidia-smi for GPU monitoring")
|
||||||
|
return ch, nil
|
||||||
|
} else {
|
||||||
|
logger.Debugf("nvidia-smi: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrNoGpuTool
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryNvidiaSmiWindows starts nvidia-smi in loop mode on Windows and returns
|
||||||
|
// a channel receiving GPU stat snapshots. Returns ErrNoGpuTool if nvidia-smi
|
||||||
|
// is not available.
|
||||||
|
func tryNvidiaSmiWindows(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
|
||||||
|
if _, err := exec.LookPath("nvidia-smi"); err != nil {
|
||||||
|
return nil, ErrNoGpuTool
|
||||||
|
}
|
||||||
|
|
||||||
|
sec := int(every.Seconds())
|
||||||
|
if sec < 1 {
|
||||||
|
sec = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "nvidia-smi",
|
||||||
|
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
"--loop", fmt.Sprintf("%d", sec),
|
||||||
|
)
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("nvidia-smi stdout pipe failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return nil, fmt.Errorf("nvidia-smi start failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan []GpuStat, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(stdout)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
stat := ParseNvidiaSmiLine(line)
|
||||||
|
if stat != nil {
|
||||||
|
select {
|
||||||
|
case ch <- []GpuStat{*stat}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmd.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readSysStats() (SysStat, error) {
|
func readSysStats() (SysStat, error) {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var simpleResponderPath string
|
||||||
|
|
||||||
|
func skipIfNoSimpleResponder(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) {
|
||||||
|
t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
goos := runtime.GOOS
|
||||||
|
goarch := runtime.GOARCH
|
||||||
|
if goos == "windows" {
|
||||||
|
simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe")
|
||||||
|
} else {
|
||||||
|
simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
|
}
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFreePort asks the OS for an unused TCP port on 127.0.0.1 by binding to
// port 0, and returns the port number that was assigned. The listener is
// closed before returning; the test is failed immediately if binding fails.
func getFreePort(t *testing.T) int {
	t.Helper()

	listener, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("getFreePort: %v", listenErr)
	}
	defer listener.Close()

	tcpAddr := listener.Addr().(*net.TCPAddr)
	return tcpAddr.Port
}
|
||||||
|
|
||||||
|
func simpleResponderCmd(t *testing.T, args ...string) (string, int) {
|
||||||
|
port := getFreePort(t)
|
||||||
|
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||||
|
base := []string{cmdPath, fmt.Sprintf("-port %d", port)}
|
||||||
|
base = append(base, args...)
|
||||||
|
return strings.Join(base, " "), port
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessState names a stage in the lifecycle of a managed process.
type ProcessState string

// Lifecycle states reported by Process.State.
const (
	StateStopped  ProcessState = "stopped"
	StateStarting ProcessState = "starting"
	StateReady    ProcessState = "ready"
	StateStopping ProcessState = "stopping"

	// StateShutdown: the process is shut down and will not be restarted.
	StateShutdown ProcessState = "shutdown"
)
|
||||||
|
|
||||||
|
type Process interface {
|
||||||
|
// Run starts the process blocks until the process is terminated.
|
||||||
|
// The timeout parameter controls how long to wait for the process to get
|
||||||
|
// to a ready state to process traffic
|
||||||
|
Run(timeout time.Duration) error
|
||||||
|
|
||||||
|
// WaitReady blocks until the process is ready to serve requests
|
||||||
|
// or the context is cancelled. It returns nil when the process is ready
|
||||||
|
WaitReady(context.Context) error
|
||||||
|
|
||||||
|
// Stop blocks until the process has terminated. It returns nil when
|
||||||
|
// the process terminated as expected (exit 0)
|
||||||
|
Stop(timeout time.Duration) error
|
||||||
|
|
||||||
|
// State returns the current state of the process
|
||||||
|
// Note: this is a snapshot of the state at the time of the call
|
||||||
|
// and may change at any time after the call returns.
|
||||||
|
State() ProcessState
|
||||||
|
|
||||||
|
// ServeHTTP forwards requests to the underlying process
|
||||||
|
// Calling it when the process is not ready will result in a
|
||||||
|
// 503 response with a body indicating it is a llama-swap-error
|
||||||
|
ServeHTTP(http.ResponseWriter, *http.Request)
|
||||||
|
|
||||||
|
// Logger returns the monitor that captures this process's stdout/stderr.
|
||||||
|
Logger() *logmon.Monitor
|
||||||
|
}
|
||||||
@@ -0,0 +1,684 @@
|
|||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrStartAborted is reported when an in-flight start is abandoned before
// the process became ready (for example because Stop arrived mid-start).
var ErrStartAborted = fmt.Errorf("aborted")

const (
	// cmdWaitDelay is the upper bound the runtime will wait for child I/O
	// to drain after the process exits before force-closing the
	// stdout/stderr pipes. It is required so that cmd.Wait() returns even
	// when a forked grandchild inherits and holds the pipes open (e.g. a
	// shell wrapper that backgrounds the real binary). killProcess sends the
	// stop signal directly (not via the cmd context), so this delay is
	// measured from process exit rather than from the stop request, and
	// stays independent of the caller's graceful timeout.
	cmdWaitDelay = 10 * time.Second

	// parentCancelGraceTimeout is the graceful timeout used when the
	// process is torn down because parentCtx was cancelled (final router
	// teardown or app shutdown). In the normal flow the process has already
	// been stopped via Stop() by this point, so killProcess is a no-op kill;
	// the short grace just bounds the rare case where a process is still
	// alive when its context is cut.
	parentCancelGraceTimeout = time.Second
)
|
||||||
|
|
||||||
|
// Message types exchanged between the public ProcessCommand methods and the
// run() goroutine.
type (
	// runReq asks run() to start the upstream process. timeout is passed to
	// doStart as the health-check timeout; the final Run result is sent on
	// respond.
	runReq struct {
		timeout time.Duration
		respond chan error
	}

	// stopReq asks run() to stop the upstream process. timeout is the
	// graceful period handed to killProcess; the result is sent on respond.
	stopReq struct {
		timeout time.Duration
		respond chan error
	}

	// waitReadyReq registers a WaitReady caller; the readiness outcome is
	// sent on respond.
	waitReadyReq struct {
		respond chan error
	}

	// startResult is what doStart hands back to run(): on success the live
	// command, a channel closed once cmd.Wait() has returned, the cancel
	// func of the command's context and the proxy handler; on failure only
	// err is set.
	startResult struct {
		cmd       *exec.Cmd
		cmdDone   chan struct{}
		cancel    context.CancelFunc
		handlerFn http.HandlerFunc
		err       error
	}
)
|
||||||
|
|
||||||
|
type ProcessCommand struct {
|
||||||
|
id string
|
||||||
|
config config.ModelConfig
|
||||||
|
parentCtx context.Context
|
||||||
|
|
||||||
|
processLogger *logmon.Monitor
|
||||||
|
proxyLogger *logmon.Monitor
|
||||||
|
|
||||||
|
// waitDelay is assigned to cmd.WaitDelay when starting the upstream
|
||||||
|
// process. Defaults to cmdWaitDelay; tests override it to keep the
|
||||||
|
// pipe-close backstop from dominating their runtime.
|
||||||
|
waitDelay time.Duration
|
||||||
|
|
||||||
|
runCh chan runReq
|
||||||
|
stopCh chan stopReq
|
||||||
|
waitReadyCh chan waitReadyReq
|
||||||
|
|
||||||
|
// current ProcessState. Written only by run(); read by State() via atomic load.
|
||||||
|
state atomic.Value
|
||||||
|
|
||||||
|
// stores the active reverse-proxy handler when the process is running.
|
||||||
|
// Written only by run(); read by ServeHTTP via atomic load.
|
||||||
|
handler atomic.Pointer[http.HandlerFunc]
|
||||||
|
|
||||||
|
lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion
|
||||||
|
inflight atomic.Int64 // current in-flight ServeHTTP calls
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Process = (*ProcessCommand)(nil)
|
||||||
|
|
||||||
|
func New(
|
||||||
|
parentCtx context.Context,
|
||||||
|
id string,
|
||||||
|
conf config.ModelConfig,
|
||||||
|
processLogger *logmon.Monitor,
|
||||||
|
proxyLogger *logmon.Monitor,
|
||||||
|
) (*ProcessCommand, error) {
|
||||||
|
p := &ProcessCommand{
|
||||||
|
id: id,
|
||||||
|
config: conf,
|
||||||
|
parentCtx: parentCtx,
|
||||||
|
processLogger: processLogger,
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
|
||||||
|
runCh: make(chan runReq),
|
||||||
|
stopCh: make(chan stopReq),
|
||||||
|
waitReadyCh: make(chan waitReadyReq),
|
||||||
|
waitDelay: cmdWaitDelay,
|
||||||
|
}
|
||||||
|
p.state.Store(StateStopped)
|
||||||
|
|
||||||
|
go p.run()
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger returns the monitor that the upstream process's stdout and stderr
// are written to.
func (p *ProcessCommand) Logger() *logmon.Monitor {
	return p.processLogger
}
|
||||||
|
|
||||||
|
// run is the single-writer goroutine that owns all mutable lifecycle state
|
||||||
|
// (current ProcessState, the running *exec.Cmd, the active reverse-proxy
|
||||||
|
// handler, and the list of WaitReady subscribers). Every public method
|
||||||
|
// (Run / Stop / State / WaitReady) is a thin client that sends a request on
|
||||||
|
// one of the channels below and waits for a response — this funnels concurrent
|
||||||
|
// callers through a single serialization point so the state machine never
|
||||||
|
// observes a race.
|
||||||
|
func (p *ProcessCommand) run() {
|
||||||
|
// Mutable state — only read/written from this goroutine. ServeHTTP reads
|
||||||
|
// p.handler concurrently, which is why handler is an atomic.Pointer.
|
||||||
|
// p.state mirrors `state` so State() can observe transitions; setState
|
||||||
|
// writes both.
|
||||||
|
state := StateStopped
|
||||||
|
setState := func(s ProcessState) {
|
||||||
|
old := state
|
||||||
|
state = s
|
||||||
|
p.state.Store(s)
|
||||||
|
if old != s {
|
||||||
|
event.Emit(shared.ProcessStateChangeEvent{
|
||||||
|
ProcessName: p.id,
|
||||||
|
OldState: string(old),
|
||||||
|
NewState: string(s),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
cmd *exec.Cmd
|
||||||
|
cmdDone <-chan struct{}
|
||||||
|
cmdCancel context.CancelFunc
|
||||||
|
readyWaiters []waitReadyReq
|
||||||
|
// runResp parks the in-flight Run caller's response channel. The
|
||||||
|
// interface contract is that Run blocks until the process is
|
||||||
|
// terminated, so we hold this until Stop, parentCtx, or an
|
||||||
|
// upstream exit unblocks it via respondRun.
|
||||||
|
runResp chan<- error
|
||||||
|
)
|
||||||
|
|
||||||
|
// notifyWaiters wakes every blocked WaitReady caller with the given result.
|
||||||
|
// Used on transitions out of StateStarting (ready, failed, aborted, or
|
||||||
|
// shutdown) — anything that resolves the "is it ready yet?" question.
|
||||||
|
notifyWaiters := func(err error) {
|
||||||
|
for _, w := range readyWaiters {
|
||||||
|
select {
|
||||||
|
case w.respond <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
readyWaiters = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// respondRun delivers the final Run result, if a Run caller is parked.
|
||||||
|
respondRun := func(err error) {
|
||||||
|
if runResp != nil {
|
||||||
|
runResp <- err
|
||||||
|
runResp = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Shutdown: parent context cancelled. Tear down any running process,
|
||||||
|
// wake any pending WaitReady callers with an error, then exit the
|
||||||
|
// goroutine permanently. Subsequent public-method calls will fail
|
||||||
|
// because parentCtx.Done() unblocks their send-side selects.
|
||||||
|
case <-p.parentCtx.Done():
|
||||||
|
// Mark shutdown before killProcess so concurrent State() readers
|
||||||
|
// stop treating this process as ready while the (possibly slow)
|
||||||
|
// teardown is in progress.
|
||||||
|
setState(StateShutdown)
|
||||||
|
if cmd != nil {
|
||||||
|
p.handler.Store(nil)
|
||||||
|
p.killProcess(cmd, cmdCancel, cmdDone, parentCancelGraceTimeout)
|
||||||
|
cmd = nil
|
||||||
|
cmdDone = nil
|
||||||
|
cmdCancel = nil
|
||||||
|
}
|
||||||
|
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||||
|
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||||
|
return
|
||||||
|
|
||||||
|
// Upstream exited on its own (not via Stop). Drop handler state,
|
||||||
|
// transition to Stopped, and unblock the parked Run caller.
|
||||||
|
// cmdDone is nil while no process is running, so this case is
|
||||||
|
// dormant outside of StateReady.
|
||||||
|
case <-cmdDone:
|
||||||
|
if cmdCancel != nil {
|
||||||
|
cmdCancel()
|
||||||
|
}
|
||||||
|
cmd = nil
|
||||||
|
cmdDone = nil
|
||||||
|
cmdCancel = nil
|
||||||
|
p.handler.Store(nil)
|
||||||
|
setState(StateStopped)
|
||||||
|
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
|
||||||
|
|
||||||
|
// WaitReady: if we're already in a terminal-for-this-question state,
|
||||||
|
// respond immediately; otherwise queue the caller and let a future
|
||||||
|
// state transition wake them via notifyWaiters.
|
||||||
|
case req := <-p.waitReadyCh:
|
||||||
|
switch state {
|
||||||
|
case StateReady:
|
||||||
|
req.respond <- nil
|
||||||
|
case StateShutdown:
|
||||||
|
req.respond <- fmt.Errorf("[%s] shutdown", p.id)
|
||||||
|
default:
|
||||||
|
readyWaiters = append(readyWaiters, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run: start the upstream process. Only valid from StateStopped.
|
||||||
|
// doStart can take a long time (health-check polling), so it runs in
|
||||||
|
// a separate goroutine and we wait on resultCh. While waiting we also
|
||||||
|
// listen for an incoming Stop — that's how callers cancel an in-flight
|
||||||
|
// start.
|
||||||
|
case req := <-p.runCh:
|
||||||
|
if state != StateStopped {
|
||||||
|
req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
setState(StateStarting)
|
||||||
|
|
||||||
|
startCtx, cancelStart := context.WithCancel(context.Background())
|
||||||
|
resultCh := make(chan startResult, 1)
|
||||||
|
go func() {
|
||||||
|
resultCh <- p.doStart(startCtx, req.timeout)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// pendingStop holds a Stop request that arrived mid-start, so we
|
||||||
|
// can respond to it AFTER we've finished tearing the start down.
|
||||||
|
var pendingStop *stopReq
|
||||||
|
select {
|
||||||
|
// doStart finished on its own — either successfully (latch
|
||||||
|
// cmd/handler and move to Ready) or with an error (back to
|
||||||
|
// Stopped). Either way wake WaitReady subscribers and reply
|
||||||
|
// to the Run caller.
|
||||||
|
case res := <-resultCh:
|
||||||
|
if res.err == nil {
|
||||||
|
cmd = res.cmd
|
||||||
|
cmdDone = res.cmdDone
|
||||||
|
cmdCancel = res.cancel
|
||||||
|
fn := res.handlerFn
|
||||||
|
p.handler.Store(&fn)
|
||||||
|
setState(StateReady)
|
||||||
|
notifyWaiters(nil)
|
||||||
|
// Park the Run response — Run blocks until the process
|
||||||
|
// terminates, so we only fire this when Stop, parentCtx,
|
||||||
|
// or the upstream exit takes the process down.
|
||||||
|
runResp = req.respond
|
||||||
|
|
||||||
|
// Start TTL goroutine if configured — self-terminates
|
||||||
|
// when state leaves StateReady.
|
||||||
|
if p.config.UnloadAfter > 0 {
|
||||||
|
ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
if p.State() != StateReady {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.inflight.Load() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration {
|
||||||
|
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter)
|
||||||
|
p.Stop(10 * time.Second)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
setState(StateStopped)
|
||||||
|
notifyWaiters(res.err)
|
||||||
|
req.respond <- res.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop arrived while doStart was still running. Cancel the
|
||||||
|
// start context to abort it, then wait for doStart to return.
|
||||||
|
// If doStart had already crossed the finish line before
|
||||||
|
// cancellation took effect, it returns a live cmd that we
|
||||||
|
// must kill ourselves. The Run caller gets ErrAbort; the Stop
|
||||||
|
// caller is parked in pendingStop and answered below.
|
||||||
|
case stop := <-p.stopCh:
|
||||||
|
cancelStart()
|
||||||
|
res := <-resultCh
|
||||||
|
if res.cmd != nil {
|
||||||
|
p.killProcess(res.cmd, res.cancel, res.cmdDone, stop.timeout)
|
||||||
|
}
|
||||||
|
setState(StateStopped)
|
||||||
|
notifyWaiters(ErrStartAborted)
|
||||||
|
req.respond <- ErrStartAborted
|
||||||
|
pendingStop = &stop
|
||||||
|
|
||||||
|
// Parent context cancelled (e.g. config reload) while doStart
|
||||||
|
// was still running. Stop() returns early when parentCtx is
|
||||||
|
// done and never sends on stopCh, so we must handle shutdown
|
||||||
|
// here to avoid leaving doStart running indefinitely.
|
||||||
|
case <-p.parentCtx.Done():
|
||||||
|
cancelStart()
|
||||||
|
// Mark shutdown before tearing the process down: killProcess
|
||||||
|
// may block (e.g. taskkill on Windows is slow to spawn), and
|
||||||
|
// callers observing State() should see StateShutdown promptly
|
||||||
|
// rather than a stale StateStarting.
|
||||||
|
setState(StateShutdown)
|
||||||
|
res := <-resultCh
|
||||||
|
if res.cmd != nil {
|
||||||
|
p.killProcess(res.cmd, res.cancel, res.cmdDone, parentCancelGraceTimeout)
|
||||||
|
}
|
||||||
|
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||||
|
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// cancelStart is idempotent; calling it again here ensures the
|
||||||
|
// context is released even on the success path (govet leak check).
|
||||||
|
cancelStart()
|
||||||
|
if pendingStop != nil {
|
||||||
|
pendingStop.respond <- nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop: tear down a running process.
|
||||||
|
case stop := <-p.stopCh:
|
||||||
|
if cmd != nil {
|
||||||
|
setState(StateStopping)
|
||||||
|
p.killProcess(cmd, cmdCancel, cmdDone, stop.timeout)
|
||||||
|
cmd = nil
|
||||||
|
cmdDone = nil
|
||||||
|
cmdCancel = nil
|
||||||
|
p.handler.Store(nil)
|
||||||
|
}
|
||||||
|
// Stop is a no-op (and not an error) when already Stopped — this
|
||||||
|
// is what makes it idempotent for callers that don't track state.
|
||||||
|
setState(StateStopped)
|
||||||
|
respondRun(nil)
|
||||||
|
stop.respond <- nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult {
|
||||||
|
if p.config.Proxy == "" {
|
||||||
|
return startResult{err: fmt.Errorf("upstream proxy missing")}
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := p.config.SanitizedCommand()
|
||||||
|
if err != nil {
|
||||||
|
return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(p.config.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL)
|
||||||
|
reverseProxy.Transport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second,
|
||||||
|
KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second,
|
||||||
|
ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second,
|
||||||
|
ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second,
|
||||||
|
}
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream
|
||||||
|
// disconnects after response headers have been sent. Recover here so the
|
||||||
|
// streaming termination is treated as a normal client/upstream disconnect.
|
||||||
|
// see: https://github.com/golang/go/issues/23643
|
||||||
|
handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
if rec == http.ErrAbortHandler {
|
||||||
|
p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
reverseProxy.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
// cmdCtx + cmd.Cancel are wired as a safety net: if the context is ever
|
||||||
|
// cancelled while the process is alive, cmd.Cancel sends SIGTERM / CmdStop
|
||||||
|
// and the runtime escalates to SIGKILL after cmd.WaitDelay. In the normal
|
||||||
|
// teardown path killProcess sends the stop signal directly instead, so
|
||||||
|
// cmd.WaitDelay only acts as the inherited-pipe backstop measured from
|
||||||
|
// process exit (see killProcess).
|
||||||
|
cmdCtx, cmdCancel := context.WithCancel(context.Background())
|
||||||
|
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
|
||||||
|
cmd.Stderr = p.processLogger
|
||||||
|
cmd.Stdout = p.processLogger
|
||||||
|
cmd.Env = append(cmd.Environ(), p.config.Env...)
|
||||||
|
cmd.Cancel = func() error { return p.sendStopSignal(cmd) }
|
||||||
|
cmd.WaitDelay = p.waitDelay
|
||||||
|
setProcAttributes(cmd)
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||||
|
|
||||||
|
cmdDone := make(chan struct{})
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
cmdCancel()
|
||||||
|
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
waitErr := cmd.Wait()
|
||||||
|
switch st := p.State(); {
|
||||||
|
case waitErr == nil:
|
||||||
|
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
|
||||||
|
case st == StateStopping || st == StateShutdown:
|
||||||
|
// Expected: we force-terminated the process. A forced kill exits
|
||||||
|
// the child with a non-zero code (e.g. taskkill /f on Windows
|
||||||
|
// yields exit status 1), so this is not an error.
|
||||||
|
p.proxyLogger.Debugf("<%s> process stopped by llama-swap: %v", p.id, waitErr)
|
||||||
|
default:
|
||||||
|
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||||
|
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(cmdDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
abort := func(err error) startResult {
|
||||||
|
p.killProcess(cmd, cmdCancel, cmdDone, 5*time.Second)
|
||||||
|
return startResult{err: err}
|
||||||
|
}
|
||||||
|
prematureExit := func() startResult {
|
||||||
|
cmdCancel()
|
||||||
|
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startCtx.Err() != nil {
|
||||||
|
return abort(ErrStartAborted)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||||
|
if checkEndpoint == "none" {
|
||||||
|
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait 250ms for the command to start up before health checking
|
||||||
|
select {
|
||||||
|
case <-startCtx.Done():
|
||||||
|
return abort(ErrStartAborted)
|
||||||
|
case <-time.After(250 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(healthCheckTimeout)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-startCtx.Done():
|
||||||
|
return abort(ErrStartAborted)
|
||||||
|
case <-cmdDone:
|
||||||
|
return prematureExit()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return abort(fmt.Errorf("health check timed out after %v", healthCheckTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
reverseProxy.ServeHTTP(rr, req)
|
||||||
|
resp := rr.Result()
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
|
||||||
|
break
|
||||||
|
} else if startCtx.Err() != nil {
|
||||||
|
return abort(ErrStartAborted)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-startCtx.Done():
|
||||||
|
return abort(ErrStartAborted)
|
||||||
|
case <-cmdDone:
|
||||||
|
return prematureExit()
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendStopSignal runs the configured CmdStop (if any) or sends SIGTERM to
|
||||||
|
// the upstream process. Wired up as cmd.Cancel so it fires whenever the
|
||||||
|
// cmd's context is cancelled.
|
||||||
|
func (p *ProcessCommand) sendStopSignal(cmd *exec.Cmd) error {
|
||||||
|
if cmd == nil || cmd.Process == nil {
|
||||||
|
p.processLogger.Debugf("<%s> sendStopSignal() called with nil cmd or process, nothing to stop", p.id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pid := cmd.Process.Pid
|
||||||
|
if p.config.CmdStop != "" {
|
||||||
|
p.processLogger.Debugf("<%s> sendStopSignal() using CmdStop %q for pid %d", p.id, p.config.CmdStop, pid)
|
||||||
|
stopArgs, err := config.SanitizeCommand(
|
||||||
|
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", pid)),
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
p.processLogger.Debugf("<%s> sendStopSignal() running stop command: %s", p.id, strings.Join(stopArgs, " "))
|
||||||
|
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||||
|
stopCmd.Env = cmd.Env
|
||||||
|
setProcAttributes(stopCmd)
|
||||||
|
runErr := stopCmd.Run()
|
||||||
|
if runErr != nil {
|
||||||
|
p.processLogger.Errorf("<%s> sendStopSignal() stop command failed: %v", p.id, runErr)
|
||||||
|
} else {
|
||||||
|
p.processLogger.Debugf("<%s> sendStopSignal() stop command completed for pid %d", p.id, pid)
|
||||||
|
}
|
||||||
|
return runErr
|
||||||
|
}
|
||||||
|
// fall through to SIGTERM if sanitize failed
|
||||||
|
p.processLogger.Errorf("<%s> sendStopSignal() failed to sanitize CmdStop %q: %v, falling back to terminateProcessTree", p.id, p.config.CmdStop, err)
|
||||||
|
}
|
||||||
|
// On Unix this SIGTERMs the whole process group so a forked grandchild
|
||||||
|
// (e.g. a shell wrapper that backgrounds the real binary) is taken down
|
||||||
|
// with the parent rather than orphaned.
|
||||||
|
p.processLogger.Debugf("<%s> sendStopSignal() no CmdStop configured, calling terminateProcessTree for pid %d", p.id, pid)
|
||||||
|
termErr := terminateProcessTree(cmd)
|
||||||
|
if termErr != nil {
|
||||||
|
p.processLogger.Errorf("<%s> sendStopSignal() terminateProcessTree failed for pid %d: %v", p.id, pid, termErr)
|
||||||
|
}
|
||||||
|
return termErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// killProcess terminates the upstream process. The flow:
|
||||||
|
//
|
||||||
|
// 1. Send the graceful stop signal (CmdStop / SIGTERM) directly — NOT by
|
||||||
|
// cancelling cmdCtx. Cancelling the context would start cmd.WaitDelay
|
||||||
|
// immediately, which force-kills the process WaitDelay after the signal
|
||||||
|
// and would silently cap gracefulTimeout at WaitDelay whenever
|
||||||
|
// gracefulTimeout is the longer of the two.
|
||||||
|
// 2. We wait up to gracefulTimeout for the process to exit on its own.
|
||||||
|
// 3. If still alive, we SIGKILL the process group directly (Unix) so any
|
||||||
|
// forked descendant is force-terminated alongside the parent.
|
||||||
|
// 4. We wait on cmdDone. cmd.WaitDelay (set when the cmd was built) is the
|
||||||
|
// critical backstop here: once the process exits, if a forked grandchild
|
||||||
|
// inherited the stdout/stderr pipes and is still holding them, the runtime
|
||||||
|
// force-closes the pipes WaitDelay after the exit and cmd.Wait() unblocks.
|
||||||
|
// Because we never cancelled the context, that WaitDelay timer measures
|
||||||
|
// from process exit (see os/exec awaitGoroutines), not from this call.
|
||||||
|
// Without WaitDelay this select would hang forever (the v219 bug).
|
||||||
|
//
|
||||||
|
// cancel() is still invoked (deferred) to release the context, but only after
|
||||||
|
// the process has exited and os/exec's ctx watcher has already torn down, so it
|
||||||
|
// never re-fires cmd.Cancel.
|
||||||
|
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cancel context.CancelFunc, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||||
|
if cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Deliver CmdStop / SIGTERM in a goroutine so a slow or hanging CmdStop
|
||||||
|
// cannot block the run() goroutine; the gracefulTimeout + Process.Kill
|
||||||
|
// path below still guarantees teardown.
|
||||||
|
if cmd != nil {
|
||||||
|
go func() {
|
||||||
|
p.proxyLogger.Debugf("[%s] sending stop signal with timeout %v", p.id, gracefulTimeout)
|
||||||
|
if err := p.sendStopSignal(cmd); err != nil {
|
||||||
|
p.proxyLogger.Warnf("[%s] stop signal failed: %v", p.id, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(gracefulTimeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-cmdDone:
|
||||||
|
return
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd != nil {
|
||||||
|
// SIGKILL the whole process group on Unix so any descendant that
|
||||||
|
// ignored or outlived the graceful signal is force-terminated too.
|
||||||
|
_ = killProcessTree(cmd)
|
||||||
|
}
|
||||||
|
<-cmdDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) ID() string {
|
||||||
|
return p.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) Run(timeout time.Duration) error {
|
||||||
|
req := runReq{
|
||||||
|
timeout: timeout,
|
||||||
|
respond: make(chan error, 1),
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case p.runCh <- req:
|
||||||
|
case <-p.parentCtx.Done():
|
||||||
|
return fmt.Errorf("[%s] shutdown", p.id)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-req.respond:
|
||||||
|
return err
|
||||||
|
case <-p.parentCtx.Done():
|
||||||
|
return fmt.Errorf("[%s] shutdown", p.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) WaitReady(ctx context.Context) error {
|
||||||
|
req := waitReadyReq{respond: make(chan error, 1)}
|
||||||
|
select {
|
||||||
|
case p.waitReadyCh <- req:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-p.parentCtx.Done():
|
||||||
|
return fmt.Errorf("[%s] shutdown", p.id)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-req.respond:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) Stop(timeout time.Duration) error {
|
||||||
|
req := stopReq{
|
||||||
|
timeout: timeout,
|
||||||
|
respond: make(chan error, 1),
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case p.stopCh <- req:
|
||||||
|
case <-p.parentCtx.Done():
|
||||||
|
return fmt.Errorf("[%s] shutdown", p.id)
|
||||||
|
}
|
||||||
|
return <-req.respond
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) State() ProcessState {
|
||||||
|
if s, ok := p.state.Load().(ProcessState); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return StateStopped
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fn := p.handler.Load()
|
||||||
|
if fn == nil {
|
||||||
|
http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.inflight.Add(1)
|
||||||
|
defer func() {
|
||||||
|
p.lastUse.Store(time.Now().UnixNano())
|
||||||
|
p.inflight.Add(-1)
|
||||||
|
}()
|
||||||
|
(*fn)(w, r)
|
||||||
|
}
|
||||||
@@ -0,0 +1,262 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestProcessCommand_StopForkingWrapper is a regression for the bug reported
|
||||||
|
// against v219 where Stop would hang indefinitely when the upstream command
|
||||||
|
// is a shell wrapper that forks the real binary (e.g. `#!/bin/bash` then
|
||||||
|
// `"$@"`). After SIGTERM the wrapper dies but the grandchild inherits the
|
||||||
|
// stdout/stderr pipes; cmd.Wait() blocks waiting for the pipe-copy goroutine
|
||||||
|
// to drain EOF, which never happens while the grandchild holds the fds.
|
||||||
|
//
|
||||||
|
// The fix is cmd.WaitDelay (combined with exec.CommandContext + cmd.Cancel),
|
||||||
|
// which causes the runtime to force-close the pipes after the delay so
|
||||||
|
// cmd.Wait() — and therefore Stop — returns.
|
||||||
|
func TestProcessCommand_StopForkingWrapper(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
port := getFreePort(t)
|
||||||
|
dir := t.TempDir()
|
||||||
|
pidFile := filepath.Join(dir, "child.pid")
|
||||||
|
|
||||||
|
// Wrapper script: backgrounds the child (which inherits stdout/stderr),
|
||||||
|
// records its PID for cleanup, then waits. When SIGTERM hits bash it
|
||||||
|
// dies without forwarding the signal; the grandchild keeps running and
|
||||||
|
// keeps the inherited pipe fds open. This is the scenario reported in
|
||||||
|
// the v219 regression.
|
||||||
|
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||||
|
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||||
|
simpleResponderPath, port, pidFile)
|
||||||
|
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||||
|
t.Fatalf("WriteFile: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||||
|
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: wrapper,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
// Shrink the pipe-close backstop so the test doesn't sit at the
|
||||||
|
// production default (10s). Must be set before Run() so doStart picks
|
||||||
|
// it up when building the cmd.
|
||||||
|
const testWaitDelay = 250 * time.Millisecond
|
||||||
|
p.waitDelay = testWaitDelay
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
|
||||||
|
// Stop must return within a bounded time even though the grandchild
|
||||||
|
// is still holding the pipe open. Budget is generous on top of
|
||||||
|
// testWaitDelay to absorb scheduling jitter on slow CI runners; the
|
||||||
|
// pre-fix behaviour was an unbounded hang, so any reasonable cap
|
||||||
|
// distinguishes pass from fail.
|
||||||
|
stopReturned := make(chan error, 1)
|
||||||
|
stopStart := time.Now()
|
||||||
|
go func() { stopReturned <- p.Stop(testStopTimeout) }()
|
||||||
|
|
||||||
|
const stopBudget = testWaitDelay + 2*time.Second
|
||||||
|
select {
|
||||||
|
case err := <-stopReturned:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stop: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Stop returned in %v", time.Since(stopStart))
|
||||||
|
case <-time.After(stopBudget):
|
||||||
|
t.Fatalf("Stop did not return within %v — cmd.Wait() likely hung on inherited pipe", stopBudget)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := p.State(); got != StateStopped {
|
||||||
|
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Errorf("Run did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_StopHonorsGracefulTimeout is a regression for the bug
|
||||||
|
// where cmd.WaitDelay capped the graceful shutdown window. killProcess used to
|
||||||
|
// cancel the cmd context to deliver SIGTERM, which starts cmd.WaitDelay
|
||||||
|
// immediately; a process whose SIGTERM handler needs longer than WaitDelay to
|
||||||
|
// finish was force-killed early even though Stop was given a much longer
|
||||||
|
// timeout. The fix sends the signal directly so WaitDelay measures from process
|
||||||
|
// exit (its inherited-pipe backstop role), leaving the graceful window to the
|
||||||
|
// caller's Stop timeout.
|
||||||
|
func TestProcessCommand_StopHonorsGracefulTimeout(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
marker := filepath.Join(dir, "graceful.done")
|
||||||
|
ready := filepath.Join(dir, "trap.ready")
|
||||||
|
|
||||||
|
// On SIGTERM, sleep past the (short) WaitDelay, then write the marker and
|
||||||
|
// exit cleanly. If WaitDelay still drove the kill, bash would be SIGKILLed
|
||||||
|
// mid-handler and the marker would never be written. The ready file is
|
||||||
|
// written only after the trap is installed so the test does not race
|
||||||
|
// SIGTERM ahead of it (CheckEndpoint:none marks ready before bash runs).
|
||||||
|
script := filepath.Join(dir, "graceful.sh")
|
||||||
|
body := fmt.Sprintf(
|
||||||
|
"#!/bin/bash\ncleanup() { sleep 0.6; echo done > %q; exit 0; }\ntrap cleanup SIGTERM\necho ready > %q\nwhile true; do sleep 0.1; done\n",
|
||||||
|
marker, ready,
|
||||||
|
)
|
||||||
|
if err := os.WriteFile(script, []byte(body), 0o755); err != nil {
|
||||||
|
t.Fatalf("WriteFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: script,
|
||||||
|
Proxy: "http://127.0.0.1:1", // unused: health check disabled
|
||||||
|
CheckEndpoint: "none",
|
||||||
|
})
|
||||||
|
// WaitDelay shorter than the handler's 0.6s sleep, and far shorter than the
|
||||||
|
// Stop timeout below — this is the window the old code mis-killed in.
|
||||||
|
p.waitDelay = 200 * time.Millisecond
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
|
||||||
|
// Wait until the trap is installed before stopping.
|
||||||
|
trapDeadline := time.Now().Add(2 * time.Second)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stat(ready); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if time.Now().After(trapDeadline) {
|
||||||
|
t.Fatalf("script did not install SIGTERM trap in time")
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
stopStart := time.Now()
|
||||||
|
if err := p.Stop(5 * time.Second); err != nil {
|
||||||
|
t.Fatalf("Stop: %v", err)
|
||||||
|
}
|
||||||
|
elapsed := time.Since(stopStart)
|
||||||
|
|
||||||
|
// The handler must have run to completion (marker written) rather than
|
||||||
|
// being force-killed at waitDelay.
|
||||||
|
if _, err := os.Stat(marker); err != nil {
|
||||||
|
t.Fatalf("graceful handler did not complete (marker missing): %v", err)
|
||||||
|
}
|
||||||
|
// And Stop must have waited for the handler (>~0.6s), not returned at the
|
||||||
|
// 200ms waitDelay.
|
||||||
|
if elapsed < 500*time.Millisecond {
|
||||||
|
t.Fatalf("Stop returned in %v — process was killed before its graceful handler finished", elapsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := p.State(); got != StateStopped {
|
||||||
|
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Errorf("Run did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_StopReapsForkedGrandchild verifies that stopping a forking
|
||||||
|
// wrapper takes down the backgrounded grandchild too, rather than leaving it as
|
||||||
|
// an orphan. The fix is Setpgid (runtime_unix.go): the wrapper leads its own
|
||||||
|
// process group, so the stop signal is delivered to the whole group via the
|
||||||
|
// negative PID and reaches the grandchild the wrapper never reaped.
|
||||||
|
func TestProcessCommand_StopReapsForkedGrandchild(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
port := getFreePort(t)
|
||||||
|
dir := t.TempDir()
|
||||||
|
pidFile := filepath.Join(dir, "child.pid")
|
||||||
|
|
||||||
|
wrapper := filepath.Join(dir, "wrapper.sh")
|
||||||
|
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
|
||||||
|
simpleResponderPath, port, pidFile)
|
||||||
|
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
|
||||||
|
t.Fatalf("WriteFile: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { killChildFromPidFile(pidFile) })
|
||||||
|
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: wrapper,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
|
||||||
|
// Read the grandchild PID the wrapper recorded.
|
||||||
|
var childPID int
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for {
|
||||||
|
data, err := os.ReadFile(pidFile)
|
||||||
|
if err == nil {
|
||||||
|
if pid, perr := strconv.Atoi(strings.TrimSpace(string(data))); perr == nil && pid > 0 {
|
||||||
|
childPID = pid
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
t.Fatalf("wrapper did not record grandchild PID")
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// After Stop the grandchild must be gone. Signal 0 probes liveness without
|
||||||
|
// actually sending a signal; give it a brief window to exit after the
|
||||||
|
// group SIGTERM.
|
||||||
|
proc, err := os.FindProcess(childPID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FindProcess: %v", err)
|
||||||
|
}
|
||||||
|
gone := false
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
if err := proc.Signal(syscall.Signal(0)); err != nil {
|
||||||
|
gone = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if !gone {
|
||||||
|
t.Errorf("grandchild PID %d still alive after Stop — process group was not reaped", childPID)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Errorf("Run did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// killChildFromPidFile kills the process whose PID is stored in pidFile so
// leaked orphans do not pile up between test runs.
//
// It is strictly best-effort: an unreadable file, content that is not a
// positive integer, a failed process lookup and a failed kill are all ignored.
func killChildFromPidFile(pidFile string) {
	raw, readErr := os.ReadFile(pidFile)
	if readErr != nil {
		return
	}

	pid, convErr := strconv.Atoi(strings.TrimSpace(string(raw)))
	if convErr != nil || pid <= 0 {
		return
	}

	if proc, findErr := os.FindProcess(pid); findErr == nil {
		// Best-effort: the process may already have exited.
		_ = proc.Kill()
	}
}
|
||||||
@@ -0,0 +1,646 @@
|
|||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Timing knobs shared by the process tests in this file.
const (
	// testStartTimeout is passed to Run and bounds the WaitReady context.
	testStartTimeout = 3 * time.Second

	// testStopTimeout is the timeout handed to Stop.
	testStopTimeout = 2 * time.Second

	// testReturnTimeout bounds how long a test waits for Run to return (or for
	// an expected log line to show up) after the triggering event.
	testReturnTimeout = 1 * time.Second

	// testPollInterval is the sleep between process state polls.
	testPollInterval = 20 * time.Millisecond

	// testLogPollInterval is the sleep between captured-log polls.
	testLogPollInterval = 10 * time.Millisecond
)
|
||||||
|
|
||||||
|
func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand {
|
||||||
|
t.Helper()
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
p, err := New(context.Background(), t.Name(), conf, logger, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New: %v", err)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// runAsync starts Run in a goroutine and waits until the process is ready,
|
||||||
|
// matching the new interface contract where Run blocks until the process is
|
||||||
|
// terminated. Returns a channel that delivers Run's eventual error.
|
||||||
|
func runAsync(t *testing.T, p *ProcessCommand) <-chan error {
|
||||||
|
t.Helper()
|
||||||
|
ch := make(chan error, 1)
|
||||||
|
go func() { ch <- p.Run(testStartTimeout) }()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := p.WaitReady(ctx); err != nil {
|
||||||
|
t.Fatalf("WaitReady: %v", err)
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessCommand_StartStop(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
|
||||||
|
// before start: no handler
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("before start: expected 503, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||||
|
t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
if got := p.State(); got != StateReady {
|
||||||
|
t.Errorf("after Run: expected state %s, got %s", StateReady, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("after Run: expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); body != "hello" {
|
||||||
|
t.Errorf("expected body %q, got %q", "hello", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
if got := p.State(); got != StateStopped {
|
||||||
|
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-runErr:
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Run() after Stop: expected nil, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatal("Run() did not return after Stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
// after stop: handler cleared
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("after stop: expected 503, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||||
|
t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessCommand_Run_Idempotent(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
cmd, port := simpleResponderCmd(t, "-silent")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
|
||||||
|
if err := p.Run(testStartTimeout); err == nil {
|
||||||
|
t.Error("second Run() while running: expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatal("Run() did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessCommand_Stop_Idempotent(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
cmd, port := simpleResponderCmd(t, "-silent")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop() before Run(): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("first Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatal("Run() did not return after Stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("second Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is
|
||||||
|
// executing its health-check loop returns ErrAbort to the Run caller.
|
||||||
|
//
|
||||||
|
// A blocking mock HTTP server is used as the proxy so the test can deterministically
|
||||||
|
// know when doStart is inside the health-check loop before issuing Stop.
|
||||||
|
func TestProcessCommand_StopCancelsRun(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
healthCheckStarted := make(chan struct{}, 1)
|
||||||
|
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Signal that a health check is in-flight, then block until the client
|
||||||
|
// cancels (which happens when Stop cancels the start context).
|
||||||
|
select {
|
||||||
|
case healthCheckStarted <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
<-r.Context().Done()
|
||||||
|
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
// simple-responder is the real process; health checks go to the blocking mock.
|
||||||
|
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: mock.URL,
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 30,
|
||||||
|
})
|
||||||
|
|
||||||
|
runErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
runErrCh <- p.Run(testStartTimeout)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Block until doStart is actually performing a health check, guaranteeing
|
||||||
|
// that Run is in-flight when Stop is called.
|
||||||
|
<-healthCheckStarted
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := <-runErrCh; !errors.Is(err, ErrStartAborted) {
|
||||||
|
t.Errorf("expected ErrStartAborted from Run, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the
|
||||||
|
// parent context while doStart is health-checking causes the process to
|
||||||
|
// transition to StateShutdown promptly, not wait for the health-check timeout.
|
||||||
|
//
|
||||||
|
// This is the config-reload race: Stop() returns early when parentCtx is
|
||||||
|
// already done and never writes to stopCh, so without a parentCtx.Done()
|
||||||
|
// case in the inner select, the process would keep loading indefinitely.
|
||||||
|
func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
healthCheckStarted := make(chan struct{}, 1)
|
||||||
|
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
select {
|
||||||
|
case healthCheckStarted <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
<-r.Context().Done()
|
||||||
|
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
parentCtx, cancelParent := context.WithCancel(context.Background())
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||||
|
p, err := New(parentCtx, t.Name(), config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: mock.URL,
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 60,
|
||||||
|
}, logger, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runErrCh := make(chan error, 1)
|
||||||
|
go func() { runErrCh <- p.Run(60 * time.Second) }()
|
||||||
|
|
||||||
|
<-healthCheckStarted
|
||||||
|
|
||||||
|
// Cancel parent context to simulate a config reload tearing down the old server.
|
||||||
|
cancelParent()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-runErrCh:
|
||||||
|
if !strings.Contains(err.Error(), "shutdown") {
|
||||||
|
t.Errorf("Run error = %v, want shutdown error", err)
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("process did not shut down within 5s after parent context cancel during start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run() may return before the run() goroutine writes StateShutdown;
|
||||||
|
// poll briefly to avoid a spurious race in the assertion.
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if p.State() == StateShutdown {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(testPollInterval)
|
||||||
|
}
|
||||||
|
if got := p.State(); got != StateShutdown {
|
||||||
|
t.Errorf("after cancel: expected StateShutdown, got %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on
|
||||||
|
// fresh processes to confirm they are reusable.
|
||||||
|
func TestProcessCommand_RunStopCycle(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
for i := range 3 {
|
||||||
|
cmd, port := simpleResponderCmd(t, "-silent")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("cycle %d Stop() error: %v", i, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatalf("cycle %d: Run() did not return after Stop", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path:
|
||||||
|
// the upstream responds healthy on /health (so Run completes), then on the
|
||||||
|
// actual proxied request it hijacks the connection and closes it mid-body.
|
||||||
|
// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error,
|
||||||
|
// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover
|
||||||
|
// and log the disconnect.
|
||||||
|
//
|
||||||
|
// Requests are issued through an httptest.NewServer wrapping the process so
|
||||||
|
// the panic actually fires (httputil only panics on copy errors when the
|
||||||
|
// request carries http.ServerContextKey, which a real server sets).
|
||||||
|
//
|
||||||
|
// see: https://github.com/golang/go/issues/23643
|
||||||
|
func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/health" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a Content-Length that promises 100 bytes, deliver only a few,
|
||||||
|
// then slam the connection shut. The reverse proxy will see EOF
|
||||||
|
// before the body is fully copied and panic with ErrAbortHandler.
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("upstream: hijack not supported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("upstream: hijack: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial"))
|
||||||
|
_ = conn.Close()
|
||||||
|
}))
|
||||||
|
t.Cleanup(upstream.Close)
|
||||||
|
|
||||||
|
// Capture proxy log output so we can assert the recover message was
|
||||||
|
// emitted by handlerFn.
|
||||||
|
logBuf := &syncBuffer{}
|
||||||
|
proxyLogger := logmon.NewWriter(logBuf)
|
||||||
|
procLogger := logmon.NewWriter(io.Discard)
|
||||||
|
|
||||||
|
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||||
|
p, err := New(context.Background(), t.Name(), config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: upstream.URL,
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
}, procLogger, proxyLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||||
|
|
||||||
|
_ = runAsync(t, p)
|
||||||
|
|
||||||
|
// Wrap p in an httptest server so requests get http.ServerContextKey
|
||||||
|
// automatically — that is what makes httputil.ReverseProxy raise the panic.
|
||||||
|
front := httptest.NewServer(p)
|
||||||
|
t.Cleanup(front.Close)
|
||||||
|
|
||||||
|
resp, err := http.Get(front.URL + "/disconnect")
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
const want = "recovered from upstream disconnection"
|
||||||
|
deadline := time.Now().Add(testReturnTimeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if strings.Contains(logBuf.String(), want) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(testLogPollInterval)
|
||||||
|
}
|
||||||
|
t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncBuffer is a bytes.Buffer guarded by a mutex so it can capture logmon
// output while another goroutine reads it back. The zero value is ready to use.
type syncBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer while holding the lock. It returns whatever
// the underlying bytes.Buffer.Write returns.
func (sb *syncBuffer) Write(p []byte) (int, error) {
	sb.mu.Lock()
	defer sb.mu.Unlock()

	return sb.buf.Write(p)
}

// String returns the buffered contents as a string, read under the lock.
func (sb *syncBuffer) String() string {
	sb.mu.Lock()
	defer sb.mu.Unlock()

	return sb.buf.String()
}
|
||||||
|
|
||||||
|
// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL
|
||||||
|
// automatically stops itself after the idle timeout has elapsed following its
|
||||||
|
// last request.
|
||||||
|
func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
t.Cleanup(mock.Close)
|
||||||
|
|
||||||
|
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||||
|
|
||||||
|
cfg := config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: mock.URL,
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
UnloadAfter: 1, // 1-second TTL
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
p := newProcessCommand(t, cfg)
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
defer func() {
|
||||||
|
if p.State() == StateReady {
|
||||||
|
p.Stop(testStopTimeout)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if got := p.State(); got != StateReady {
|
||||||
|
t.Fatalf("expected StateReady, got %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make one request to prime the last-use timestamp.
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the TTL goroutine to fire and the process to fully stop.
|
||||||
|
// Poll for StateStopped directly to avoid racing the StateStopping
|
||||||
|
// intermediate state that sits between StateReady and StateStopped.
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for p.State() != StateStopped && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(testPollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := p.State(); got != StateStopped {
|
||||||
|
t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run() should have returned nil (clean stop from TTL).
|
||||||
|
select {
|
||||||
|
case err := <-runErr:
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Run() after TTL stop: expected nil, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatal("Run() did not return after TTL-induced stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests
|
||||||
|
// prevent the TTL goroutine from stopping the process, and that the TTL timer
|
||||||
|
// resets after each request completes.
|
||||||
|
func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
t.Cleanup(mock.Close)
|
||||||
|
|
||||||
|
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: mock.URL,
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
UnloadAfter: 1, // 1-second TTL
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
defer func() {
|
||||||
|
if p.State() == StateReady {
|
||||||
|
p.Stop(testStopTimeout)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Keep sending requests for 1.5s — past the 1s TTL — and verify
|
||||||
|
// the process never stops while traffic is flowing.
|
||||||
|
stopAt := time.Now().Add(1500 * time.Millisecond)
|
||||||
|
for time.Now().Before(stopAt) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if p.State() != StateReady {
|
||||||
|
t.Fatalf("process was stopped during active traffic (state=%s)", p.State())
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := p.State(); got != StateReady {
|
||||||
|
t.Fatalf("expected StateReady while traffic was active, got %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now stop manually to clean up.
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatal("Run() did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not
|
||||||
|
// spawn a TTL goroutine — the process stays ready until explicitly stopped.
|
||||||
|
func TestProcessCommand_TTL_ZeroDisables(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
t.Cleanup(mock.Close)
|
||||||
|
|
||||||
|
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: mock.URL,
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
UnloadAfter: 0, // disabled
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
defer func() {
|
||||||
|
if p.State() == StateReady {
|
||||||
|
p.Stop(testStopTimeout)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if got := p.State(); got != StateReady {
|
||||||
|
t.Fatalf("expected StateReady, got %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is
|
||||||
|
// enough to confirm the process remains ready.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
if got := p.State(); got != StateReady {
|
||||||
|
t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanly stop.
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-runErr:
|
||||||
|
case <-time.After(testReturnTimeout):
|
||||||
|
t.Fatal("Run() did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing
|
||||||
|
// pairs to exercise the race detector and verify no deadlocks occur.
|
||||||
|
func TestProcessCommand_ConcurrentRunStop(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
cmd, port := simpleResponderCmd(t, "-silent")
|
||||||
|
cfg := config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
p := newProcessCommand(t, cfg)
|
||||||
|
|
||||||
|
runDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(runDone)
|
||||||
|
p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
p.Stop(testStopTimeout) //nolint: errcheck
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Backstop: the racing Stop may have arrived before Run got on the
|
||||||
|
// channel (making it a no-op), so keep stopping until Run unblocks.
|
||||||
|
deadline := time.After(testStartTimeout)
|
||||||
|
for done := false; !done; {
|
||||||
|
select {
|
||||||
|
case <-runDone:
|
||||||
|
done = true
|
||||||
|
case <-deadline:
|
||||||
|
t.Fatal("Run did not return")
|
||||||
|
case <-time.After(testPollInterval):
|
||||||
|
p.Stop(testStopTimeout) //nolint: errcheck
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) {
|
||||||
|
skipIfNoSimpleResponder(t)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var transitions []shared.ProcessStateChangeEvent
|
||||||
|
cancel := event.On(func(e shared.ProcessStateChangeEvent) {
|
||||||
|
if e.ProcessName != t.Name() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
transitions = append(transitions, e)
|
||||||
|
mu.Unlock()
|
||||||
|
})
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||||
|
p := newProcessCommand(t, config.ModelConfig{
|
||||||
|
Cmd: cmd,
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr := runAsync(t, p)
|
||||||
|
if err := p.Stop(testStopTimeout); err != nil {
|
||||||
|
t.Fatalf("Stop: %v", err)
|
||||||
|
}
|
||||||
|
<-runErr
|
||||||
|
|
||||||
|
// Events are delivered asynchronously; give the dispatcher a moment.
|
||||||
|
deadline := time.Now().Add(time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
mu.Lock()
|
||||||
|
n := len(transitions)
|
||||||
|
mu.Unlock()
|
||||||
|
if n >= 4 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(testPollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
for _, e := range transitions {
|
||||||
|
if e.OldState == e.NewState {
|
||||||
|
t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{
|
||||||
|
string(StateStopped) + "->" + string(StateStarting),
|
||||||
|
string(StateStarting) + "->" + string(StateReady),
|
||||||
|
string(StateReady) + "->" + string(StateStopping),
|
||||||
|
string(StateStopping) + "->" + string(StateStopped),
|
||||||
|
}
|
||||||
|
got := make([]string, len(transitions))
|
||||||
|
for i, e := range transitions {
|
||||||
|
got[i] = e.OldState + "->" + e.NewState
|
||||||
|
}
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("transitions = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("transitions = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcAttributes replaces cmd.SysProcAttr with one that sets Setpgid, so
// the upstream is started in a process group of its own. The whole tree can
// then be signalled in one call through the negative PID, which is how a
// forked grandchild (for example a shell wrapper that backgrounds the real
// binary and exits) is reaped rather than left behind as an orphan holding
// the inherited stdout/stderr pipes open.
func setProcAttributes(cmd *exec.Cmd) {
	attr := new(syscall.SysProcAttr)
	attr.Setpgid = true
	cmd.SysProcAttr = attr
}
|
||||||
|
|
||||||
|
// terminateProcessTree sends SIGTERM to the whole process group led by the
|
||||||
|
// command, giving every process in the tree a chance to shut down gracefully.
|
||||||
|
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||||
|
return signalProcessTree(cmd, syscall.SIGTERM)
|
||||||
|
}
|
||||||
|
|
||||||
|
// killProcessTree sends SIGKILL to the whole process group, force-terminating
|
||||||
|
// every process in the tree.
|
||||||
|
func killProcessTree(cmd *exec.Cmd) error {
|
||||||
|
return signalProcessTree(cmd, syscall.SIGKILL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// signalProcessTree delivers sig to the process group led by cmd.Process.
// The child is started with Setpgid (see setProcAttributes), which makes it
// its own group leader (pgid == pid), so signalling -pid reaches the child
// and every descendant still in the group.
//
// Behaviour by case:
//   - cmd or cmd.Process is nil (nothing was started): no-op, returns nil.
//   - cmd.Process.Pid <= 0: returns syscall.EINVAL without signalling. Negating
//     such a value would not address the child at all: kill(0, sig) targets
//     the caller's own process group and kill(1, sig) targets PID 1.
//   - the group send fails (e.g. the group has already drained): falls back
//     to signalling just the child and returns that result, so the signal
//     is never silently skipped.
func signalProcessTree(cmd *exec.Cmd, sig syscall.Signal) error {
	if cmd == nil || cmd.Process == nil {
		return nil
	}
	pid := cmd.Process.Pid
	if pid <= 0 {
		// Not a usable child PID; refuse rather than hit an unrelated target.
		return syscall.EINVAL
	}
	if err := syscall.Kill(-pid, sig); err != nil {
		return cmd.Process.Signal(sig)
	}
	return nil
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcAttributes sets platform-specific process attributes. CREATE_NO_WINDOW
|
||||||
|
// keeps the upstream from spawning its own console window.
|
||||||
|
func setProcAttributes(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
HideWindow: true,
|
||||||
|
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// terminateProcessTree requests a graceful shutdown of the whole process tree
|
||||||
|
// rooted at cmd.Process. Windows has no SIGTERM or process-group signalling, so
|
||||||
|
// we shell out to `taskkill /t`, which walks the child tree by PID — the
|
||||||
|
// equivalent of signalling a Unix process group. Without /f, taskkill asks the
|
||||||
|
// processes to close rather than force-killing them.
|
||||||
|
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||||
|
return taskkillProcessTree(cmd, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// killProcessTree force-terminates the whole process tree rooted at cmd.Process
|
||||||
|
// via `taskkill /f /t`, so any descendant that ignored or outlived the graceful
|
||||||
|
// request is killed alongside the parent rather than leaked as an orphan.
|
||||||
|
func killProcessTree(cmd *exec.Cmd) error {
|
||||||
|
return taskkillProcessTree(cmd, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskkillProcessTree runs taskkill against cmd.Process.Pid. The /t flag
|
||||||
|
// terminates the process together with any child processes it started, which is
|
||||||
|
// the Windows analogue of signalling a Unix process group via its negative PID.
|
||||||
|
// When force is true the /f flag force-kills; otherwise taskkill requests a
|
||||||
|
// graceful close.
|
||||||
|
func taskkillProcessTree(cmd *exec.Cmd, force bool) error {
|
||||||
|
if cmd == nil || cmd.Process == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
args := make([]string, 0, 4)
|
||||||
|
if force {
|
||||||
|
args = append(args, "/f")
|
||||||
|
}
|
||||||
|
args = append(args, "/t", "/pid", fmt.Sprintf("%d", cmd.Process.Pid))
|
||||||
|
kill := exec.Command("taskkill", args...)
|
||||||
|
setProcAttributes(kill)
|
||||||
|
return kill.Run()
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
// SetupTreeCleanup does nothing on non-Windows platforms and always returns
// nil. Upstream process teardown is handled there through process-group
// signalling (see runtime_unix.go), so no parent-side setup is needed.
func SetupTreeCleanup() error {
	return nil
}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupTreeCleanup assigns the current process to a Windows Job Object
|
||||||
|
// configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. Upstream processes
|
||||||
|
// spawned afterwards are associated with the same job, so when llama-swap exits
|
||||||
|
// for any reason — graceful shutdown, a forced second Ctrl+C, or a crash — the
|
||||||
|
// OS terminates the whole job and reaps every child instead of leaving orphans
|
||||||
|
// behind. It is the parent-side complement to the per-process teardown in
|
||||||
|
// runtime_windows.go.
|
||||||
|
//
|
||||||
|
// The job handle is intentionally leaked for the lifetime of the process: the
|
||||||
|
// kill-on-close behaviour fires when the last handle is released, which the OS
|
||||||
|
// does when the process exits.
|
||||||
|
func SetupTreeCleanup() error {
|
||||||
|
job, err := windows.CreateJobObject(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("CreateJobObject: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
|
||||||
|
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
|
||||||
|
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := windows.SetInformationJobObject(
|
||||||
|
job,
|
||||||
|
windows.JobObjectExtendedLimitInformation,
|
||||||
|
uintptr(unsafe.Pointer(&info)),
|
||||||
|
uint32(unsafe.Sizeof(info)),
|
||||||
|
); err != nil {
|
||||||
|
windows.CloseHandle(job)
|
||||||
|
return fmt.Errorf("SetInformationJobObject: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := windows.AssignProcessToJobObject(job, windows.CurrentProcess()); err != nil {
|
||||||
|
windows.CloseHandle(job)
|
||||||
|
return fmt.Errorf("AssignProcessToJobObject: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,800 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Messages exchanged with, and bookkeeping kept by, the baseRouter run loop.
type (
	// shutdownReq is the message carried on baseRouter.shutdownCh.
	shutdownReq struct {
		timeout time.Duration
		respond chan error // carries the shutdown result back to the requester
	}

	// unloadReq is the message carried on baseRouter.unloadCh.
	unloadReq struct {
		targets []string
		timeout time.Duration
		respond chan struct{}
	}

	// handlerReq is one caller's request for a handler that serves model.
	handlerReq struct {
		model      string
		ctx        context.Context  // grant gives up on the caller once this is done
		respond    chan handlerResp // where grant delivers the answer
		positionCh chan int         // receives the 1-indexed queue position while queued
	}

	// handlerResp is the answer to a handlerReq: either a handler to invoke
	// or the error explaining why none is available.
	handlerResp struct {
		handleFunc http.HandlerFunc
		err        error
	}

	// swapDone is reported on baseRouter.swapDoneCh when the swap for
	// modelID has finished; err is non-nil when the swap failed.
	swapDone struct {
		modelID string
		err     error
	}

	// serveDoneEvent is reported on baseRouter.serveDoneCh each time a
	// tracked ServeHTTP call for modelID returns.
	serveDoneEvent struct {
		modelID string
	}

	// activeSwap is the run loop's record of one in-flight swap.
	activeSwap struct {
		modelID string       // model being swapped in
		evict   []string     // eviction set this swap was started with
		waiters []handlerReq // callers to answer when the swap finishes
	}
)
|
||||||
|
|
||||||
|
// swapPlanner captures the one behaviour that varies between concrete
// routers: how the eviction set for a swap is chosen. baseRouter only calls
// these two methods and never looks inside the implementation.
type swapPlanner interface {
	// EvictionFor reports which running model IDs have to be stopped before
	// target can serve. alsoRunning names the models baseRouter has already
	// committed to loading (in-flight swaps) that the planner cannot observe
	// through process.State() yet. It is a pure decision and must not log.
	EvictionFor(target string, alsoRunning []string) []string

	// OnSwapStart is invoked once at the start of every swap. A planner may
	// log its decision here at whatever verbosity it chooses.
	OnSwapStart(target string)
}
|
||||||
|
|
||||||
|
// baseRouter owns the channels, run-loop, and orchestration code shared by
|
||||||
|
// every concrete router. Concrete routers embed *baseRouter and supply a
|
||||||
|
// swapPlanner that captures how their eviction set is decided.
|
||||||
|
type baseRouter struct {
|
||||||
|
name string
|
||||||
|
config config.Config
|
||||||
|
processes map[string]process.Process
|
||||||
|
logger *logmon.Monitor
|
||||||
|
planner swapPlanner
|
||||||
|
|
||||||
|
// shutdownCtx governs the request machinery: cancelling it tells grant()
|
||||||
|
// and ServeHTTP to stop granting and reject callers. It is deliberately
|
||||||
|
// separate from procCtx — see procCtx below.
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownFn context.CancelFunc
|
||||||
|
shuttingDown atomic.Bool
|
||||||
|
|
||||||
|
// procCtx is the parent context for every managed process and governs
|
||||||
|
// process lifetime only. handleShutdown stops processes gracefully via
|
||||||
|
// Stop() and cancels procCtx afterwards, so teardown is never a context
|
||||||
|
// cancel racing the graceful path (which collapsed the grace to 100ms and
|
||||||
|
// let the caller return before children were reaped — see process run loop).
|
||||||
|
procCtx context.Context
|
||||||
|
procCancel context.CancelFunc
|
||||||
|
|
||||||
|
handlerCh chan handlerReq
|
||||||
|
shutdownCh chan shutdownReq
|
||||||
|
unloadCh chan unloadReq
|
||||||
|
swapDoneCh chan swapDone
|
||||||
|
serveDoneCh chan serveDoneEvent
|
||||||
|
|
||||||
|
runDone chan struct{}
|
||||||
|
|
||||||
|
// testProcessed, when non-nil, receives one event after each handlerReq
|
||||||
|
// or swapDone has been fully processed by run(). Tests use it to wait
|
||||||
|
// for run() to reach a deterministic state without sleeping. serveDone
|
||||||
|
// events are intentionally NOT signalled here so test event counts
|
||||||
|
// remain stable.
|
||||||
|
testProcessed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
|
||||||
|
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||||
|
procCtx, procCancel := context.WithCancel(context.Background())
|
||||||
|
return &baseRouter{
|
||||||
|
name: name,
|
||||||
|
config: conf,
|
||||||
|
processes: processes,
|
||||||
|
logger: logger,
|
||||||
|
planner: planner,
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownFn: shutdownFn,
|
||||||
|
procCtx: procCtx,
|
||||||
|
procCancel: procCancel,
|
||||||
|
handlerCh: make(chan handlerReq),
|
||||||
|
shutdownCh: make(chan shutdownReq),
|
||||||
|
unloadCh: make(chan unloadReq),
|
||||||
|
swapDoneCh: make(chan swapDone),
|
||||||
|
serveDoneCh: make(chan serveDoneEvent),
|
||||||
|
runDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) notifyProcessed() {
|
||||||
|
if b.testProcessed != nil {
|
||||||
|
b.testProcessed <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) run() {
|
||||||
|
defer close(b.runDone)
|
||||||
|
|
||||||
|
active := make(map[string]*activeSwap)
|
||||||
|
inFlight := make(map[string]int)
|
||||||
|
var queued []handlerReq
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case req := <-b.shutdownCh:
|
||||||
|
b.handleShutdown(req, active, queued)
|
||||||
|
return
|
||||||
|
|
||||||
|
case req := <-b.handlerCh:
|
||||||
|
b.handleRequest(req, active, inFlight, &queued)
|
||||||
|
b.notifyProcessed()
|
||||||
|
|
||||||
|
case req := <-b.unloadCh:
|
||||||
|
b.handleUnload(req, active, inFlight, &queued)
|
||||||
|
b.notifyProcessed()
|
||||||
|
|
||||||
|
case ev := <-b.swapDoneCh:
|
||||||
|
b.handleSwapDone(ev, active, inFlight, &queued)
|
||||||
|
b.notifyProcessed()
|
||||||
|
|
||||||
|
case ev := <-b.serveDoneCh:
|
||||||
|
b.handleServeDone(ev, active, inFlight, &queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// grant sends a response back to the caller of ServeHTTP and tells us
|
||||||
|
// whether the caller was still there to receive it.
|
||||||
|
//
|
||||||
|
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
|
||||||
|
// a select waiting on it. "Unbuffered" is the important word: a send only
|
||||||
|
// completes when the other side is actively receiving. So if this send
|
||||||
|
// succeeds, we know for a fact the caller picked up the response and will
|
||||||
|
// act on it. If the caller has already given up (its request context was
|
||||||
|
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
|
||||||
|
// down, the send never lands, one of the other select cases fires, and we
|
||||||
|
// report back that the grant did NOT happen.
|
||||||
|
//
|
||||||
|
// That distinction matters for in-flight bookkeeping — see grantHandler.
|
||||||
|
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
|
||||||
|
select {
|
||||||
|
case req.respond <- resp:
|
||||||
|
return true
|
||||||
|
case <-req.ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-b.shutdownCtx.Done():
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// grantHandler is the "this caller can now use process p" path. It does
|
||||||
|
// two things that must stay locked together:
|
||||||
|
//
|
||||||
|
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
|
||||||
|
// HTTP request finishes, the run loop hears about it.
|
||||||
|
// 2. Bump inFlight[modelID] so the router knows this process is busy and
|
||||||
|
// refuses to evict it until the count comes back down.
|
||||||
|
//
|
||||||
|
// The increment is gated on grant() returning true. If grant() returns
|
||||||
|
// false, the caller already walked away and trackedServe will never run —
|
||||||
|
// which means no matching decrement will ever arrive on serveDoneCh.
|
||||||
|
// Incrementing in that case would strand the counter at >0 forever and
|
||||||
|
// the router would never again be willing to swap this model out.
|
||||||
|
//
|
||||||
|
// In short: increment if and only if we know a decrement is coming.
|
||||||
|
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
|
||||||
|
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
|
||||||
|
inFlight[modelID]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
||||||
|
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
|
||||||
|
// send on serveDoneCh after the handler returns. That send is what tells
|
||||||
|
// the run loop "this model now has one fewer request in flight — go look
|
||||||
|
// at the queue again, you may be able to start a swap you previously had
|
||||||
|
// to defer."
|
||||||
|
//
|
||||||
|
// The select on shutdownCtx.Done() is a release valve: if the router is
|
||||||
|
// already shutting down, nobody is reading serveDoneCh, so we drop the
|
||||||
|
// notification rather than blocking the HTTP goroutine forever.
|
||||||
|
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
select {
|
||||||
|
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
|
||||||
|
case <-b.shutdownCtx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
p.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRequest decides what to do with one incoming ServeHTTP request. It is
|
||||||
|
// called from run() and never blocks indefinitely: any work that has to wait
|
||||||
|
// (starting a process, stopping siblings, waiting for ready) is deferred to
|
||||||
|
// a swap goroutine and reported back via swapDoneCh.
|
||||||
|
//
|
||||||
|
// The decision tree, in order:
|
||||||
|
//
|
||||||
|
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
|
||||||
|
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||||
|
// one swap serves all callers that asked for the same model.
|
||||||
|
// 3. Fast path — the target process is already ready, the planner sees
|
||||||
|
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||||
|
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
|
||||||
|
// 4. Would collide with an in-flight swap (we'd stop their target, or
|
||||||
|
// they're stopping us) — park in the queue for handleSwapDone to drain.
|
||||||
|
// 5. Would evict a process that is still handling requests — park in the
|
||||||
|
// queue. handleServeDone will retry when the busy process drains.
|
||||||
|
// 6. Otherwise — start a new swap. This may run in parallel with other
|
||||||
|
// active swaps when their evict sets don't intersect.
|
||||||
|
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||||
|
// (1) Unknown model.
|
||||||
|
p, ok := b.processes[req.model]
|
||||||
|
if !ok {
|
||||||
|
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
|
||||||
|
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (2) Join an in-flight swap for the same model.
|
||||||
|
if s, ok := active[req.model]; ok {
|
||||||
|
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
|
||||||
|
s.waiters = append(s.waiters, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||||
|
|
||||||
|
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||||
|
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||||
|
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
|
||||||
|
b.grantHandler(req, req.model, p, inFlight)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (4) Collision with an in-flight swap — queue.
|
||||||
|
if collidesWith(req.model, evict, active) {
|
||||||
|
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
|
||||||
|
*queued = append(*queued, req)
|
||||||
|
b.broadcastQueuePositions(*queued)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (5) Would evict a busy process — queue until it drains.
|
||||||
|
if conflictsWithInFlight(evict, inFlight) {
|
||||||
|
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
|
||||||
|
*queued = append(*queued, req)
|
||||||
|
b.broadcastQueuePositions(*queued)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// (6) Start a new (possibly parallel) swap.
|
||||||
|
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
|
||||||
|
s := b.startSwap(req, evict)
|
||||||
|
active[s.modelID] = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSwapDone is called from run() when a swap goroutine reports that it
|
||||||
|
// has finished. It fans out the result to every waiter that joined this swap,
|
||||||
|
// removes the swap from the active map, and then walks the queue once,
|
||||||
|
// promoting any items that no longer collide with the remaining active set.
|
||||||
|
// FIFO order is preserved: items still blocked stay in place.
|
||||||
|
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||||
|
s, ok := active[ev.modelID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(active, ev.modelID)
|
||||||
|
|
||||||
|
for _, w := range s.waiters {
|
||||||
|
if ev.err != nil {
|
||||||
|
b.grant(w, handlerResp{err: ev.err})
|
||||||
|
} else {
|
||||||
|
p := b.processes[ev.modelID]
|
||||||
|
b.grantHandler(w, ev.modelID, p, inFlight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.drainQueue(active, inFlight, queued)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleServeDone is called from run() each time a tracked ServeHTTP
|
||||||
|
// finishes. It decrements the per-model in-flight count and, when that
|
||||||
|
// drops to zero, retries the queue: requests whose swap was deferred
|
||||||
|
// because they would have evicted this (now-idle) process can now proceed.
|
||||||
|
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||||
|
inFlight[ev.modelID]--
|
||||||
|
if inFlight[ev.modelID] <= 0 {
|
||||||
|
delete(inFlight, ev.modelID)
|
||||||
|
b.drainQueue(active, inFlight, queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainQueue walks the queued requests in order, re-running the handleRequest
|
||||||
|
// decision tree against the (now smaller) active set. Items that can now start
|
||||||
|
// or join become satisfied; items still blocked remain queued in original
|
||||||
|
// order so they get another chance on the next swap completion.
|
||||||
|
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||||
|
if len(*queued) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pending := *queued
|
||||||
|
var remaining []handlerReq
|
||||||
|
for _, req := range pending {
|
||||||
|
p, ok := b.processes[req.model]
|
||||||
|
if !ok {
|
||||||
|
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s, ok := active[req.model]; ok {
|
||||||
|
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
|
||||||
|
s.waiters = append(s.waiters, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||||
|
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||||
|
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
|
||||||
|
b.grantHandler(req, req.model, p, inFlight)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if collidesWith(req.model, evict, active) {
|
||||||
|
remaining = append(remaining, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if conflictsWithInFlight(evict, inFlight) {
|
||||||
|
remaining = append(remaining, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
|
||||||
|
s := b.startSwap(req, evict)
|
||||||
|
active[s.modelID] = s
|
||||||
|
}
|
||||||
|
*queued = remaining
|
||||||
|
b.broadcastQueuePositions(*queued)
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||||
|
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||||
|
// drained first so the consumer always sees the latest position.
|
||||||
|
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
|
||||||
|
for i, req := range queued {
|
||||||
|
pos := i + 1
|
||||||
|
select {
|
||||||
|
case req.positionCh <- pos:
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case <-req.positionCh:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case req.positionCh <- pos:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
|
||||||
|
swap := &activeSwap{
|
||||||
|
modelID: initial.model,
|
||||||
|
evict: evict,
|
||||||
|
waiters: []handlerReq{initial},
|
||||||
|
}
|
||||||
|
b.planner.OnSwapStart(initial.model)
|
||||||
|
go b.doSwap(initial.model, evict)
|
||||||
|
return swap
|
||||||
|
}
|
||||||
|
|
||||||
|
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||||
|
// baseRouter passes this to the planner so eviction decisions account for
|
||||||
|
// models that have been committed to but have not yet transitioned to
|
||||||
|
// StateStarting in their process state machine.
|
||||||
|
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||||
|
if len(active) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(active))
|
||||||
|
for id := range active {
|
||||||
|
if id == exclude {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// collidesWith reports whether a new swap with this target and evict set can
|
||||||
|
// safely run alongside the currently active swaps. Same-target callers should
|
||||||
|
// JOIN (handled before this) — they do not collide with themselves.
|
||||||
|
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||||
|
for id, s := range active {
|
||||||
|
if id == target {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if containsString(evict, id) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if containsString(s.evict, target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// conflictsWithInFlight reports whether at least one model named in evict has
// a positive in-flight request count. Stopping a process that is busy would
// cancel its callers' connections, so the router postpones the swap until
// those callers are done.
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
	for i := range evict {
		if busy := inFlight[evict[i]]; busy > 0 {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// containsString reports whether s is one of the elements of xs. A nil or
// empty xs contains nothing.
func containsString(xs []string, s string) bool {
	for i := 0; i < len(xs); i++ {
		if xs[i] == s {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||||
|
timeout := b.healthCheckTimeout()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, mID := range toStop {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p process.Process, id string) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := p.Stop(timeout); err != nil {
|
||||||
|
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||||
|
}
|
||||||
|
}(b.processes[mID], mID)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
target := b.processes[modelID]
|
||||||
|
if target.State() == process.StateStopped {
|
||||||
|
go func() {
|
||||||
|
if err := target.Run(timeout); err != nil {
|
||||||
|
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := target.WaitReady(b.shutdownCtx)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
|
||||||
|
case <-b.shutdownCtx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
|
||||||
|
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
||||||
|
|
||||||
|
// Cancel shutdownCtx first so any waiter that is currently parked on
|
||||||
|
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
||||||
|
// The grant calls below then either land (waiter happened to receive
|
||||||
|
// before noticing shutdown) or fall through immediately via grant's
|
||||||
|
// shutdownCtx case — either way the waiter sees a non-OK response.
|
||||||
|
// This does NOT touch processes: their lifetime is procCtx, cancelled
|
||||||
|
// only after the graceful Stop() calls below have reaped them.
|
||||||
|
b.shutdownFn()
|
||||||
|
|
||||||
|
for _, s := range active {
|
||||||
|
for _, w := range s.waiters {
|
||||||
|
b.grant(w, handlerResp{err: shutdownErr})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, w := range queued {
|
||||||
|
b.grant(w, handlerResp{err: shutdownErr})
|
||||||
|
}
|
||||||
|
|
||||||
|
stopTimeout := req.timeout
|
||||||
|
if stopTimeout <= 0 {
|
||||||
|
stopTimeout = b.healthCheckTimeout()
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, p := range b.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id string, p process.Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := p.Stop(stopTimeout); err != nil {
|
||||||
|
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
|
||||||
|
}
|
||||||
|
}(i, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if req.timeout > 0 {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(req.timeout):
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
|
||||||
|
// the process run-loop goroutines exit; they are already StateStopped, so
|
||||||
|
// this is a clean no-op kill rather than a forced teardown.
|
||||||
|
b.procCancel()
|
||||||
|
|
||||||
|
req.respond <- nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) healthCheckTimeout() time.Duration {
|
||||||
|
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
|
||||||
|
if t <= 0 {
|
||||||
|
return 30 * time.Second
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) Handles(model string) bool {
|
||||||
|
_, ok := b.processes[model]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
|
||||||
|
if p, ok := b.processes[modelID]; ok {
|
||||||
|
return p.Logger(), true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunningModels returns the current state of every process that is not stopped
|
||||||
|
// or shut down. The processes map keys are fixed at construction and State()
|
||||||
|
// is a snapshot, so this is safe to call without the run loop.
|
||||||
|
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
|
||||||
|
running := make(map[string]process.ProcessState)
|
||||||
|
for id, p := range b.processes {
|
||||||
|
st := p.State()
|
||||||
|
if st == process.StateStopped || st == process.StateShutdown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
running[id] = st
|
||||||
|
}
|
||||||
|
return running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload stops the named models, or every running model when none are named.
|
||||||
|
// It blocks until each targeted process has stopped.
|
||||||
|
//
|
||||||
|
// The request is funneled through the run loop so eviction is coordinated
|
||||||
|
// with the rest of the router's state: pending swap waiters for an
|
||||||
|
// unloaded model are released with an error, queued requests for unloaded
|
||||||
|
// models are dropped, and any deferred swaps that were waiting on those
|
||||||
|
// models become eligible to start.
|
||||||
|
//
|
||||||
|
// In-flight requests being served by an unloaded process are not waited
|
||||||
|
// for — Stop kills the upstream, those callers see whatever error the
|
||||||
|
// reverse proxy surfaces and may retry. Their trackedServe defers fire
|
||||||
|
// normally and decrement inFlight as the dying handlers return.
|
||||||
|
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
|
||||||
|
targets := models
|
||||||
|
if len(targets) == 0 {
|
||||||
|
targets = make([]string, 0, len(b.processes))
|
||||||
|
for id := range b.processes {
|
||||||
|
targets = append(targets, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(targets) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
|
||||||
|
select {
|
||||||
|
case b.unloadCh <- req:
|
||||||
|
case <-b.runDone:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
<-req.respond
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUnload runs on the run loop in response to an Unload call. It
|
||||||
|
// reconciles router-owned state with the impending Stop, then performs
|
||||||
|
// the Stop synchronously so callers of Unload remain blocked until each
|
||||||
|
// targeted process has actually exited.
|
||||||
|
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||||
|
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
|
||||||
|
|
||||||
|
targetSet := make(map[string]bool, len(req.targets))
|
||||||
|
for _, id := range req.targets {
|
||||||
|
targetSet[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release waiters of any in-flight swap whose target is being
|
||||||
|
// unloaded. The swap goroutine itself is left to finish on its own;
|
||||||
|
// when its swapDone arrives, handleSwapDone will find no entry in
|
||||||
|
// active and silently drop it.
|
||||||
|
for id := range targetSet {
|
||||||
|
s, ok := active[id]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, w := range s.waiters {
|
||||||
|
b.grant(w, handlerResp{err: unloadErr})
|
||||||
|
}
|
||||||
|
delete(active, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop queued requests addressed to unloaded models. Requests for
|
||||||
|
// other models stay queued and may benefit from drainQueue at the end.
|
||||||
|
if len(*queued) > 0 {
|
||||||
|
kept := (*queued)[:0]
|
||||||
|
for _, w := range *queued {
|
||||||
|
if targetSet[w.model] {
|
||||||
|
b.grant(w, handlerResp{err: unloadErr})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, w)
|
||||||
|
}
|
||||||
|
*queued = kept
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the targeted processes. Done synchronously so Unload's caller
|
||||||
|
// can rely on "after Unload returns, the process is stopped". inFlight
|
||||||
|
// is intentionally NOT cleared here: each dying handler will fire its
|
||||||
|
// trackedServe defer and reach handleServeDone in the normal way once
|
||||||
|
// the run loop is free again.
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for id := range targetSet {
|
||||||
|
p, ok := b.processes[id]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id string, p process.Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := p.Stop(req.timeout); err != nil {
|
||||||
|
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
|
||||||
|
}
|
||||||
|
}(id, p)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Removing entries from active above may have unblocked queued
|
||||||
|
// requests that previously collided with the now-cancelled swaps.
|
||||||
|
b.drainQueue(active, inFlight, queued)
|
||||||
|
|
||||||
|
close(req.respond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||||
|
if !b.shuttingDown.CompareAndSwap(false, true) {
|
||||||
|
return fmt.Errorf("%s shutdown already in progress", b.name)
|
||||||
|
}
|
||||||
|
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
|
||||||
|
select {
|
||||||
|
case b.shutdownCh <- req:
|
||||||
|
case <-b.runDone:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return <-req.respond
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if b.shuttingDown.Load() {
|
||||||
|
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := FetchContext(req, b.config)
|
||||||
|
if err != nil {
|
||||||
|
SendError(w, req, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hr := handlerReq{
|
||||||
|
model: data.ModelID,
|
||||||
|
ctx: req.Context(),
|
||||||
|
// Unbuffered: a successful send on respond proves the waiter is
|
||||||
|
// alive and consuming. grant() relies on this to avoid handing a
|
||||||
|
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
||||||
|
respond: make(chan handlerResp),
|
||||||
|
positionCh: make(chan int, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case b.handlerCh <- hr:
|
||||||
|
case <-req.Context().Done():
|
||||||
|
return
|
||||||
|
case <-b.shutdownCtx.Done():
|
||||||
|
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isModelReady := false
|
||||||
|
if p, ok := b.processes[data.ModelID]; ok {
|
||||||
|
isModelReady = p.State() == process.StateReady
|
||||||
|
}
|
||||||
|
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
|
||||||
|
|
||||||
|
var lw *loadingWriter
|
||||||
|
cancelLoad := func() {}
|
||||||
|
if shouldShowLoading {
|
||||||
|
var swapCtx context.Context
|
||||||
|
swapCtx, cancelLoad = context.WithCancel(req.Context())
|
||||||
|
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
|
||||||
|
go lw.start(swapCtx)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case pos := <-hr.positionCh:
|
||||||
|
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
||||||
|
case <-swapCtx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// finishLoading stops the loading stream and fences its goroutine off from
|
||||||
|
// the ResponseWriter before the real handler (or ServeHTTP's return)
|
||||||
|
// reclaims it. release() must run even when waitForCompletion times out:
|
||||||
|
// otherwise a still-streaming goroutine flushes a finalized response and
|
||||||
|
// panics on the recycled *bufio.Writer.
|
||||||
|
finishLoading := func() {
|
||||||
|
cancelLoad()
|
||||||
|
if lw != nil {
|
||||||
|
lw.waitForCompletion(1 * time.Second)
|
||||||
|
lw.release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp handlerResp
|
||||||
|
select {
|
||||||
|
case resp = <-hr.respond:
|
||||||
|
finishLoading()
|
||||||
|
case <-req.Context().Done():
|
||||||
|
finishLoading()
|
||||||
|
return
|
||||||
|
case <-b.shutdownCtx.Done():
|
||||||
|
finishLoading()
|
||||||
|
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.err != nil {
|
||||||
|
SendError(w, req, resp.err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.handleFunc(w, req)
|
||||||
|
}
|
||||||
@@ -0,0 +1,863 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubPlanner is a swapPlanner driven by a canned table: for each target it
// returns a fixed eviction list, and it never logs. It lets the base-router
// tests exercise the shared run-loop behaviour without pulling in the
// eviction rules of either real router.
type stubPlanner struct {
	// evict maps a target model ID to the models its swap must stop.
	// A nil map means no target evicts anything.
	evict map[string][]string
}

// EvictionFor returns the canned eviction list for target, or nil when the
// table has no entry for it. The list of active targets is ignored.
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
	// Indexing a nil map yields the zero value, so no nil check is needed.
	return s.evict[target]
}

// OnSwapStart is a no-op.
func (s *stubPlanner) OnSwapStart(string) {}
|
||||||
|
|
||||||
|
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
|
||||||
|
t.Helper()
|
||||||
|
conf := config.Config{HealthCheckTimeout: 5}
|
||||||
|
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||||
|
b.testProcessed = make(chan struct{}, 64)
|
||||||
|
go b.run()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if !b.shuttingDown.Load() {
|
||||||
|
_ = b.Shutdown(time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_RunningModels(t *testing.T) {
|
||||||
|
ready := newFakeProcess("ready")
|
||||||
|
ready.markReady()
|
||||||
|
starting := newFakeProcess("starting")
|
||||||
|
starting.setState(process.StateStarting)
|
||||||
|
stopped := newFakeProcess("stopped")
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{
|
||||||
|
"ready": ready, "starting": starting, "stopped": stopped,
|
||||||
|
}, &stubPlanner{})
|
||||||
|
|
||||||
|
running := b.RunningModels()
|
||||||
|
if len(running) != 2 {
|
||||||
|
t.Fatalf("running=%v want 2 entries", running)
|
||||||
|
}
|
||||||
|
if running["ready"] != process.StateReady {
|
||||||
|
t.Errorf("ready state=%q want ready", running["ready"])
|
||||||
|
}
|
||||||
|
if running["starting"] != process.StateStarting {
|
||||||
|
t.Errorf("starting state=%q want starting", running["starting"])
|
||||||
|
}
|
||||||
|
if _, ok := running["stopped"]; ok {
|
||||||
|
t.Errorf("stopped process should be excluded from RunningModels")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_UnloadAll(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
c := newFakeProcess("c")
|
||||||
|
c.markReady()
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
|
||||||
|
b.Unload(time.Second)
|
||||||
|
|
||||||
|
if a.State() != process.StateStopped || c.State() != process.StateStopped {
|
||||||
|
t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_UnloadSpecificModel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
c := newFakeProcess("c")
|
||||||
|
c.markReady()
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
|
||||||
|
b.Unload(time.Second, "a")
|
||||||
|
|
||||||
|
if a.State() != process.StateStopped {
|
||||||
|
t.Errorf("a should be stopped, got %q", a.State())
|
||||||
|
}
|
||||||
|
if c.State() != process.StateReady {
|
||||||
|
t.Errorf("c should remain ready, got %q", c.State())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its
|
||||||
|
// Stop calls concurrently rather than stopping each process serially. Each
|
||||||
|
// fakeProcess.Stop is pinned via stopBlock; the test only releases them
|
||||||
|
// after observing every stopStarted, proving all three Stops were in
|
||||||
|
// flight simultaneously.
|
||||||
|
func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
a.stopBlock = make(chan struct{})
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
pb.markReady()
|
||||||
|
pb.stopBlock = make(chan struct{})
|
||||||
|
pc := newFakeProcess("c")
|
||||||
|
pc.markReady()
|
||||||
|
pc.stopBlock = make(chan struct{})
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{})
|
||||||
|
|
||||||
|
unloadDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.Unload(time.Second, "a", "b", "c")
|
||||||
|
close(unloadDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// All three Stop calls must start before any of them are allowed to
|
||||||
|
// complete. If Unload was serial, only one stopStarted would fire
|
||||||
|
// until we released its stopBlock, and this would deadlock.
|
||||||
|
for _, p := range []*fakeProcess{a, pb, pc} {
|
||||||
|
select {
|
||||||
|
case <-p.stopStarted:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("Stop on %s never started — Unload is not parallel", p.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release them; Unload should now return.
|
||||||
|
close(a.stopBlock)
|
||||||
|
close(pb.stopBlock)
|
||||||
|
close(pc.stopBlock)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-unloadDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Unload did not return after stops released")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range []*fakeProcess{a, pb, pc} {
|
||||||
|
if p.State() != process.StateStopped {
|
||||||
|
t.Errorf("%s state=%q want stopped", p.id, p.State())
|
||||||
|
}
|
||||||
|
if got := p.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("%s stopCalls=%d want 1", p.id, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
|
||||||
|
// rejoins router state: a request whose swap to the unloaded model is
|
||||||
|
// still in progress receives an error, instead of being abandoned
|
||||||
|
// against a process that's about to vanish.
|
||||||
|
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
// autoReady=false: the swap parks on WaitReady so we can interrupt
|
||||||
|
// it with Unload before it completes.
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w, newRequest("a"))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
|
||||||
|
<-a.runStarted
|
||||||
|
|
||||||
|
b.Unload(time.Second, "a")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("ServeHTTP did not return after Unload")
|
||||||
|
}
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if a.State() != process.StateStopped {
|
||||||
|
t.Errorf("a state=%q want stopped", a.State())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
|
||||||
|
// for an unloaded model receive an error rather than sitting forever in
|
||||||
|
// the queue against state the router no longer maintains.
|
||||||
|
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
// Loading B evicts A — so a request for B while A is loading queues.
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||||
|
|
||||||
|
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
<-a.runStarted
|
||||||
|
|
||||||
|
// r2 for B collides with A's in-flight swap and queues.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
// Unload B — r2 (queued, targeting B) must be released with an error.
|
||||||
|
b.Unload(time.Second, "b")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done2:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("queued B request did not return after Unload(b)")
|
||||||
|
}
|
||||||
|
if w2.Code == http.StatusOK {
|
||||||
|
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
|
||||||
|
}
|
||||||
|
if got := pb.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release r1 so the test cleans up cleanly.
|
||||||
|
a.markReady()
|
||||||
|
select {
|
||||||
|
case <-done1:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("r1 did not complete after a.markReady")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_FastPath(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest("a"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.serveCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("serveCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := a.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.autoReady = true
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest("a"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := a.serveCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("serveCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
// autoReady=false so the swap parks on WaitReady until we release it.
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|
||||||
|
const N = 5
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
codes := make([]int, N)
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest("a"))
|
||||||
|
codes[i] = w.Code
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
|
||||||
|
<-a.runStarted // swap goroutine reached Run()
|
||||||
|
a.markReady()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i, c := range codes {
|
||||||
|
if c != http.StatusOK {
|
||||||
|
t.Errorf("request %d: status=%d", i, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := a.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
|
||||||
|
}
|
||||||
|
if got := a.serveCalls.Load(); got != N {
|
||||||
|
t.Errorf("serveCalls=%d want %d", got, N)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
// autoReady=false so swap parks forever until we mark ready.
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w1, newRequestCtx(ctx, "a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w2, newRequest("a"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap
|
||||||
|
<-a.runStarted
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done1:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("cancelled ServeHTTP did not return after ctx cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
a.markReady()
|
||||||
|
select {
|
||||||
|
case <-done2:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("non-cancelled ServeHTTP did not complete after swap")
|
||||||
|
}
|
||||||
|
if w2.Code != http.StatusOK {
|
||||||
|
t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pa := newFakeProcess("b")
|
||||||
|
|
||||||
|
// Loading b must stop a.
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
|
||||||
|
|
||||||
|
// First request starts a swap to A; A's autoReady=false so it parks.
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
<-a.runStarted
|
||||||
|
|
||||||
|
// Second request for B should queue while A's swap is in flight.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
if got := pa.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release A's swap. B's swap should then run.
|
||||||
|
a.markReady()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
|
||||||
|
<-pa.runStarted
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done1:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("A request did not complete")
|
||||||
|
}
|
||||||
|
pa.markReady()
|
||||||
|
select {
|
||||||
|
case <-done2:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("queued B request did not complete after A's swap")
|
||||||
|
}
|
||||||
|
if w2.Code != http.StatusOK {
|
||||||
|
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
|
||||||
|
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
|
||||||
|
// second request for each model rides the fast path — either joining the
|
||||||
|
// active swap, or being pulled out of the queue when handleSwapDone promotes
|
||||||
|
// the next model.
|
||||||
|
func TestBaseRouter_QueueCollation(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
pc := newFakeProcess("c")
|
||||||
|
|
||||||
|
// Each model evicts the other two so all swaps are mutually exclusive.
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{
|
||||||
|
"a": {"b", "c"},
|
||||||
|
"b": {"a", "c"},
|
||||||
|
"c": {"a", "b"},
|
||||||
|
}}
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||||
|
|
||||||
|
var (
|
||||||
|
completedMu sync.Mutex
|
||||||
|
completed []string
|
||||||
|
)
|
||||||
|
record := func(id string) {
|
||||||
|
completedMu.Lock()
|
||||||
|
defer completedMu.Unlock()
|
||||||
|
completed = append(completed, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := []string{"a", "b", "c", "a", "b", "c"}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, id := range ids {
|
||||||
|
id := id
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest(id))
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record(id)
|
||||||
|
}()
|
||||||
|
// Wait for run() to absorb this request before launching the next,
|
||||||
|
// so handlerCh receives them in launch order.
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All 6 are now parked in run()'s waiters/queue. Release each swap in
|
||||||
|
// sequence, waiting deterministically for each promotion to fire.
|
||||||
|
<-a.runStarted
|
||||||
|
a.markReady()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
|
||||||
|
|
||||||
|
<-pb.runStarted
|
||||||
|
pb.markReady()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
|
||||||
|
|
||||||
|
<-pc.runStarted
|
||||||
|
pc.markReady()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if got := len(completed); got != 6 {
|
||||||
|
t.Fatalf("completed=%v want 6", completed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
|
||||||
|
// but waiter goroutines may be scheduled in any order after their respond
|
||||||
|
// channel fires, so completion order isn't deterministic. Per-model counts
|
||||||
|
// (combined with the runCalls checks below) are sufficient to prove queue
|
||||||
|
// collation collapsed each pair into a single swap.
|
||||||
|
aDone, bDone, cDone := 0, 0, 0
|
||||||
|
for _, id := range completed {
|
||||||
|
switch id {
|
||||||
|
case "a":
|
||||||
|
aDone++
|
||||||
|
case "b":
|
||||||
|
bDone++
|
||||||
|
case "c":
|
||||||
|
cDone++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if aDone != 2 || bDone != 2 || cDone != 2 {
|
||||||
|
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single swap per model — the second request for each must have ridden
|
||||||
|
// the fast path (joined active swap or joined a queued sibling), not
|
||||||
|
// triggered an extra Run.
|
||||||
|
if got := a.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := pb.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := pc.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("c.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
|
||||||
|
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
|
||||||
|
// before either process is marked ready.
|
||||||
|
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
|
||||||
|
// Empty evict sets for both: they can load in parallel.
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||||
|
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
// Both swaps must have reached Run() before either is marked ready —
|
||||||
|
// proves they ran in parallel rather than serializing.
|
||||||
|
<-a.runStarted
|
||||||
|
<-pb.runStarted
|
||||||
|
|
||||||
|
a.markReady()
|
||||||
|
pb.markReady()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done1:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("request A did not complete")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done2:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("request B did not complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if w1.Code != http.StatusOK {
|
||||||
|
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
|
||||||
|
}
|
||||||
|
if w2.Code != http.StatusOK {
|
||||||
|
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||||
|
}
|
||||||
|
if got := pb.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
|
||||||
|
// unblocks every queued request that no longer collides — they all start in
|
||||||
|
// parallel rather than one-per-completion.
|
||||||
|
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
pc := newFakeProcess("c")
|
||||||
|
|
||||||
|
// A's swap evicts both B and C, so B and C must queue. Once A finishes
|
||||||
|
// B and C themselves have empty evict sets, so they can start together.
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{
|
||||||
|
"a": {"b", "c"},
|
||||||
|
}}
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||||
|
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
<-a.runStarted
|
||||||
|
|
||||||
|
// B and C arrive while A is loading. evict_b and evict_c are empty,
|
||||||
|
// but collidesWith returns true because they appear in A's evict set.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
done3 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w3, newRequest("c"))
|
||||||
|
close(done3)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
if got := pb.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b started early: runCalls=%d", got)
|
||||||
|
}
|
||||||
|
if got := pc.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("c started early: runCalls=%d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release A. The swapDone handler should drain the queue and start
|
||||||
|
// both B and C in parallel.
|
||||||
|
a.markReady()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
|
||||||
|
<-pb.runStarted
|
||||||
|
<-pc.runStarted
|
||||||
|
|
||||||
|
pb.markReady()
|
||||||
|
pc.markReady()
|
||||||
|
|
||||||
|
for i, ch := range []chan struct{}{done1, done2, done3} {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("request %d did not complete", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
|
||||||
|
// the shutdown error to every waiter on every active swap AND to every
|
||||||
|
// queued request.
|
||||||
|
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
pc := newFakeProcess("c")
|
||||||
|
|
||||||
|
// a and b load in parallel (empty evicts). c collides with both.
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{
|
||||||
|
"c": {"a", "b"},
|
||||||
|
}}
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||||
|
|
||||||
|
const waitersPer = 2
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
codes := make([]int, 0, 2*waitersPer+1)
|
||||||
|
var codesMu sync.Mutex
|
||||||
|
record := func(code int) {
|
||||||
|
codesMu.Lock()
|
||||||
|
codes = append(codes, code)
|
||||||
|
codesMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
launch := func(model string) {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest(model))
|
||||||
|
record(w.Code)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Active swaps for a and b, each with 2 waiters.
|
||||||
|
for i := 0; i < waitersPer; i++ {
|
||||||
|
launch("a")
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
}
|
||||||
|
for i := 0; i < waitersPer; i++ {
|
||||||
|
launch("b")
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
}
|
||||||
|
// c collides with both → queues.
|
||||||
|
launch("c")
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
<-a.runStarted
|
||||||
|
<-pb.runStarted
|
||||||
|
|
||||||
|
if err := b.Shutdown(time.Second); err != nil {
|
||||||
|
t.Fatalf("Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
codesMu.Lock()
|
||||||
|
defer codesMu.Unlock()
|
||||||
|
if len(codes) != 2*waitersPer+1 {
|
||||||
|
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
|
||||||
|
}
|
||||||
|
for i, c := range codes {
|
||||||
|
if c == http.StatusOK {
|
||||||
|
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
|
||||||
|
// is not stopped to satisfy another model's swap while it is still handling
|
||||||
|
// a request.
|
||||||
|
//
|
||||||
|
// Sequence:
|
||||||
|
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
|
||||||
|
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
|
||||||
|
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
|
||||||
|
// intent collides with A.
|
||||||
|
// 4. r1 released — A finishes r1, then r3 is served by A.
|
||||||
|
// 5. B's swap then proceeds; r2 is served by B.
|
||||||
|
//
|
||||||
|
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
|
||||||
|
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
|
||||||
|
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
// autoReady left false: we markReady manually after observing runStarted,
|
||||||
|
// so autoReady's setState(Ready) cannot race with a later Stop and leave
|
||||||
|
// A in Ready, masking the bug.
|
||||||
|
a.serveBlock = make(chan struct{})
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
// Same reasoning for B: park its swap on WaitReady until we choose.
|
||||||
|
|
||||||
|
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||||
|
|
||||||
|
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
|
||||||
|
<-a.runStarted
|
||||||
|
a.markReady()
|
||||||
|
waitProcessed(t, b.testProcessed, 1) // swapDone for A
|
||||||
|
<-a.serveStarted
|
||||||
|
|
||||||
|
// r2 — would evict A. A must not be stopped while r1 is in flight.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
// r3 — another request for A, arrives behind r2 and queues because
|
||||||
|
// B's swap intent (which evicts A) is recorded as active.
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
done3 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
b.ServeHTTP(w3, newRequest("a"))
|
||||||
|
close(done3)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, b.testProcessed, 1)
|
||||||
|
|
||||||
|
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
|
||||||
|
// The router must hold off B's swap until A has drained.
|
||||||
|
close(a.serveBlock)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done1:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("r1 did not complete after serveBlock release")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for B.Run before marking it ready: markReady before Run would
|
||||||
|
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
|
||||||
|
// implementation B's swap only starts after A has drained; in the
|
||||||
|
// current implementation it has already started — either way runStarted
|
||||||
|
// fires.
|
||||||
|
<-pb.runStarted
|
||||||
|
pb.markReady()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done2:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("r2 did not complete after B marked ready")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done3:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("r3 did not complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
|
||||||
|
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
|
||||||
|
}
|
||||||
|
if w1.Body.String() != "ok:a" {
|
||||||
|
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
|
||||||
|
}
|
||||||
|
if w3.Body.String() != "ok:a" {
|
||||||
|
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
|
||||||
|
}
|
||||||
|
if w2.Body.String() != "ok:b" {
|
||||||
|
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
|
||||||
|
}
|
||||||
|
if a.stoppedWhileServing.Load() {
|
||||||
|
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest("unknown"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0)
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
pb.markReady()
|
||||||
|
go pb.Run(0)
|
||||||
|
|
||||||
|
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||||
|
|
||||||
|
if err := b.Shutdown(time.Second); err != nil {
|
||||||
|
t.Fatalf("Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := pb.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.stopCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent ServeHTTP should report 5xx.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
b.ServeHTTP(w, newRequest("a"))
|
||||||
|
if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second Shutdown should report already in progress.
|
||||||
|
if err := b.Shutdown(0); err == nil {
|
||||||
|
t.Errorf("second Shutdown returned nil, want error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Group struct {
|
||||||
|
*baseRouter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
||||||
|
modelToGroup := make(map[string]string)
|
||||||
|
for gid, gcfg := range conf.Groups {
|
||||||
|
for _, mid := range gcfg.Members {
|
||||||
|
if existing, dup := modelToGroup[mid]; dup {
|
||||||
|
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
||||||
|
}
|
||||||
|
modelToGroup[mid] = gid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
planner := &groupPlanner{
|
||||||
|
config: conf,
|
||||||
|
modelToGroup: modelToGroup,
|
||||||
|
}
|
||||||
|
|
||||||
|
processes := make(map[string]process.Process, len(modelToGroup))
|
||||||
|
base := newBaseRouter("group", conf, processes, planner, proxylog)
|
||||||
|
planner.processes = processes
|
||||||
|
|
||||||
|
for mid := range modelToGroup {
|
||||||
|
modelCfg, _, ok := conf.FindConfig(mid)
|
||||||
|
if !ok {
|
||||||
|
base.shutdownFn()
|
||||||
|
base.procCancel()
|
||||||
|
return nil, fmt.Errorf("no model config for %q", mid)
|
||||||
|
}
|
||||||
|
procLog := logmon.NewWriter(upstreamlog)
|
||||||
|
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||||
|
if err != nil {
|
||||||
|
base.shutdownFn()
|
||||||
|
base.procCancel()
|
||||||
|
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||||
|
}
|
||||||
|
processes[mid] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
g := &Group{baseRouter: base}
|
||||||
|
go base.run()
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupPlanner decides evictions from static group configuration.
|
||||||
|
//
|
||||||
|
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
||||||
|
// members are stopped only when the target's group is exclusive; loading a
|
||||||
|
// model from a non-exclusive group leaves running exclusive groups alone,
|
||||||
|
// matching the gotcha in the original ProcessGroup behaviour.
|
||||||
|
type groupPlanner struct {
|
||||||
|
config config.Config
|
||||||
|
modelToGroup map[string]string
|
||||||
|
processes map[string]process.Process
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||||
|
tg := p.modelToGroup[target]
|
||||||
|
tgCfg := p.config.Groups[tg]
|
||||||
|
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
var result []string
|
||||||
|
consider := func(mID string) {
|
||||||
|
if mID == target {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, dup := seen[mID]; dup {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
og := p.modelToGroup[mID]
|
||||||
|
switch {
|
||||||
|
case og == tg && tgCfg.Swap:
|
||||||
|
seen[mID] = struct{}{}
|
||||||
|
result = append(result, mID)
|
||||||
|
// the previous ProcessGroup behaviour did not unload exclusive groups
|
||||||
|
// when loading a non-exclusive model. This maintains that gotcha
|
||||||
|
// for backwards compatibility. The newer swap matrix approach does not
|
||||||
|
// have this issue.
|
||||||
|
case og != tg && tgCfg.Exclusive:
|
||||||
|
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
|
||||||
|
seen[mID] = struct{}{}
|
||||||
|
result = append(result, mID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for mID, proc := range p.processes {
|
||||||
|
st := proc.State()
|
||||||
|
if st == process.StateStopped || st == process.StateShutdown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
consider(mID)
|
||||||
|
}
|
||||||
|
for _, mID := range alsoRunning {
|
||||||
|
consider(mID)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *groupPlanner) OnSwapStart(target string) {}
|
||||||
@@ -0,0 +1,331 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestGroup builds a Group directly from the supplied processes and config,
|
||||||
|
// bypassing NewGroup's call to process.New.
|
||||||
|
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
||||||
|
t.Helper()
|
||||||
|
modelToGroup := make(map[string]string)
|
||||||
|
for gid, gcfg := range conf.Groups {
|
||||||
|
for _, mid := range gcfg.Members {
|
||||||
|
modelToGroup[mid] = gid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
planner := &groupPlanner{
|
||||||
|
config: conf,
|
||||||
|
modelToGroup: modelToGroup,
|
||||||
|
processes: processes,
|
||||||
|
}
|
||||||
|
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||||
|
base.testProcessed = make(chan struct{}, 64)
|
||||||
|
g := &Group{baseRouter: base}
|
||||||
|
go base.run()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if !g.shuttingDown.Load() {
|
||||||
|
_ = g.Shutdown(time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
||||||
|
conf := config.Config{
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g1": {Swap: true, Members: []string{"a"}},
|
||||||
|
"g2": {Swap: true, Members: []string{"a"}},
|
||||||
|
},
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"a": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
log := logmon.NewWriter(io.Discard)
|
||||||
|
if _, err := NewGroup(conf, log, log); err == nil {
|
||||||
|
t.Fatalf("expected error for duplicate membership")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0) // park a Run goroutine so Stop has something to release
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
g.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := b.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := b.serveCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.serveCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
g.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got)
|
||||||
|
}
|
||||||
|
if got := b.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroup_CrossGroupExclusive(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0)
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||||
|
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
g.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for
|
||||||
|
// models in distinct non-exclusive groups load in parallel rather than
|
||||||
|
// serializing through the router's run loop.
|
||||||
|
func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
||||||
|
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||||
|
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
g.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, g.testProcessed, 1)
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
g.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, g.testProcessed, 1)
|
||||||
|
|
||||||
|
// Both groups load concurrently — both must reach Run() before either is
|
||||||
|
// marked ready. If the router still serialised, only one would proceed.
|
||||||
|
<-a.runStarted
|
||||||
|
<-pb.runStarted
|
||||||
|
|
||||||
|
a.markReady()
|
||||||
|
pb.markReady()
|
||||||
|
|
||||||
|
for i, ch := range []chan struct{}{done1, done2} {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("request %d did not complete", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
|
||||||
|
}
|
||||||
|
if got := pb.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
||||||
|
// (Swap=true) serialise even when both arrive while neither has reached
|
||||||
|
// StateStarting yet — the alsoRunning hint to the planner closes that race.
|
||||||
|
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||||
|
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
g.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, g.testProcessed, 1)
|
||||||
|
|
||||||
|
// Request B arrives before A transitions to StateStarting in the process
|
||||||
|
// state machine. Without the alsoRunning hint, the planner would not see
|
||||||
|
// A as running, and B would start in parallel, violating Swap=true.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
g.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, g.testProcessed, 1)
|
||||||
|
|
||||||
|
if got := pb.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b started in parallel: runCalls=%d want 0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-a.runStarted
|
||||||
|
a.markReady()
|
||||||
|
waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted
|
||||||
|
<-pb.runStarted
|
||||||
|
pb.markReady()
|
||||||
|
|
||||||
|
for i, ch := range []chan struct{}{done1, done2} {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("request %d did not complete", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_PersistentNotEvicted verifies that a group with persistent=true
|
||||||
|
// is never evicted when another exclusive group starts loading. The running
|
||||||
|
// model in the persistent group stays alive alongside the new one.
|
||||||
|
func TestGroup_PersistentNotEvicted(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0)
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
||||||
|
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
g.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got)
|
||||||
|
}
|
||||||
|
if a.State() != process.StateStarting && a.State() != process.StateReady {
|
||||||
|
t.Errorf("a state=%s want still running", a.State())
|
||||||
|
}
|
||||||
|
if got := b.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible
|
||||||
|
// gotcha from the original ProcessGroup: when a model in a non-exclusive group
|
||||||
|
// is loaded, any running exclusive group keeps running. The two coexist.
|
||||||
|
func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0)
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
conf := config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||||
|
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
g.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got)
|
||||||
|
}
|
||||||
|
if a.State() != process.StateStarting && a.State() != process.StateReady {
|
||||||
|
t.Errorf("a state=%s want still running", a.State())
|
||||||
|
}
|
||||||
|
if got := b.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeProcess is an in-memory implementation of process.Process used to drive
|
||||||
|
// the routers through their state machine without spawning real upstreams.
|
||||||
|
type fakeProcess struct {
|
||||||
|
id string
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
state process.ProcessState
|
||||||
|
readyCh chan struct{}
|
||||||
|
stopCh chan struct{}
|
||||||
|
runStarted chan struct{} // closed on the first Run call
|
||||||
|
stopStarted chan struct{} // closed on the first Stop call
|
||||||
|
|
||||||
|
autoReady bool
|
||||||
|
|
||||||
|
// serveBlock, when non-nil, makes ServeHTTP receive from it before
|
||||||
|
// writing its response. Tests use this to hold a request in-flight.
|
||||||
|
// Closing the channel releases every blocked ServeHTTP caller.
|
||||||
|
serveBlock chan struct{}
|
||||||
|
// serveStarted is closed on the first ServeHTTP entry, letting tests
|
||||||
|
// wait deterministically for the handler to begin executing.
|
||||||
|
serveStarted chan struct{}
|
||||||
|
// stopBlock, when non-nil, makes Stop receive from it (after signalling
|
||||||
|
// stopStarted) before completing. Tests use this to prove that several
|
||||||
|
// Stop calls can be in flight simultaneously.
|
||||||
|
stopBlock chan struct{}
|
||||||
|
|
||||||
|
runCalls atomic.Int32
|
||||||
|
stopCalls atomic.Int32
|
||||||
|
serveCalls atomic.Int32
|
||||||
|
|
||||||
|
// inFlightServe counts ServeHTTP calls currently inside the handler.
|
||||||
|
// stoppedWhileServing flips true if Stop is ever called while that
|
||||||
|
// counter is non-zero — a direct, race-free observation of the
|
||||||
|
// "swap mid-request" anti-property.
|
||||||
|
inFlightServe atomic.Int32
|
||||||
|
stoppedWhileServing atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeProcess(id string) *fakeProcess {
|
||||||
|
return &fakeProcess{
|
||||||
|
id: id,
|
||||||
|
state: process.StateStopped,
|
||||||
|
readyCh: make(chan struct{}),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
runStarted: make(chan struct{}),
|
||||||
|
stopStarted: make(chan struct{}),
|
||||||
|
serveStarted: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeProcess) setState(s process.ProcessState) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.state = s
|
||||||
|
if s == process.StateReady {
|
||||||
|
select {
|
||||||
|
case <-f.readyCh:
|
||||||
|
default:
|
||||||
|
close(f.readyCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeProcess) State() process.ProcessState {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// markReady is shorthand for moving the fake straight to StateReady.
func (f *fakeProcess) markReady() {
	f.setState(process.StateReady)
}
|
||||||
|
|
||||||
|
func (f *fakeProcess) Run(_ time.Duration) error {
|
||||||
|
f.runCalls.Add(1)
|
||||||
|
f.mu.Lock()
|
||||||
|
if f.state != process.StateStopped {
|
||||||
|
s := f.state
|
||||||
|
f.mu.Unlock()
|
||||||
|
return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s)
|
||||||
|
}
|
||||||
|
f.state = process.StateStarting
|
||||||
|
sc := f.stopCh
|
||||||
|
select {
|
||||||
|
case <-f.runStarted:
|
||||||
|
default:
|
||||||
|
close(f.runStarted)
|
||||||
|
}
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
if f.autoReady {
|
||||||
|
f.setState(process.StateReady)
|
||||||
|
}
|
||||||
|
<-sc
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeProcess) Stop(_ time.Duration) error {
|
||||||
|
f.stopCalls.Add(1)
|
||||||
|
if f.inFlightServe.Load() > 0 {
|
||||||
|
f.stoppedWhileServing.Store(true)
|
||||||
|
}
|
||||||
|
f.mu.Lock()
|
||||||
|
select {
|
||||||
|
case <-f.stopStarted:
|
||||||
|
default:
|
||||||
|
close(f.stopStarted)
|
||||||
|
}
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
// Test hook: hold Stop here so the test can prove multiple Stops are
|
||||||
|
// in flight at the same time before any of them complete.
|
||||||
|
if f.stopBlock != nil {
|
||||||
|
<-f.stopBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
if f.state == process.StateStopped {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
f.state = process.StateStopped
|
||||||
|
select {
|
||||||
|
case <-f.stopCh:
|
||||||
|
default:
|
||||||
|
close(f.stopCh)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeProcess) WaitReady(ctx context.Context) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
if f.state == process.StateReady {
|
||||||
|
f.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rc := f.readyCh
|
||||||
|
f.mu.Unlock()
|
||||||
|
select {
|
||||||
|
case <-rc:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger returns a new monitor that writes to io.Discard, so anything the
// router logs through the fake is dropped.
func (f *fakeProcess) Logger() *logmon.Monitor {
	return logmon.NewWriter(io.Discard)
}
|
||||||
|
|
||||||
|
func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
f.serveCalls.Add(1)
|
||||||
|
f.inFlightServe.Add(1)
|
||||||
|
defer f.inFlightServe.Add(-1)
|
||||||
|
f.mu.Lock()
|
||||||
|
select {
|
||||||
|
case <-f.serveStarted:
|
||||||
|
default:
|
||||||
|
close(f.serveStarted)
|
||||||
|
}
|
||||||
|
f.mu.Unlock()
|
||||||
|
if f.serveBlock != nil {
|
||||||
|
<-f.serveBlock
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, "ok:%s", f.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitProcessed receives n events from ch, allowing up to two seconds for
// each one, and fails the test immediately if an event does not arrive in
// time. One event is expected per handlerReq or swapDone fully absorbed by
// run().
func waitProcessed(t *testing.T, ch chan struct{}, n int) {
	t.Helper()

	const perEvent = 2 * time.Second
	for received := 0; received < n; received++ {
		select {
		case <-time.After(perEvent):
			t.Fatalf("waitProcessed: only %d/%d events received", received, n)
		case <-ch:
		}
	}
}
|
||||||
|
|
||||||
|
// newRequest builds a POST /v1/chat/completions test request whose JSON body
// is {"model": <model>} (the name is quoted with %q) and whose Content-Type
// is application/json.
func newRequest(model string) *http.Request {
	payload := strings.NewReader(fmt.Sprintf(`{"model":%q}`, model))
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	return req
}
|
||||||
|
|
||||||
|
func newRequestCtx(ctx context.Context, model string) *http.Request {
|
||||||
|
return newRequest(model).WithContext(ctx)
|
||||||
|
}
|
||||||
@@ -0,0 +1,277 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
// loadingPaths holds the request path prefixes that isLoadingPath accepts.
var loadingPaths = []string{
	"/v1/chat/completions",
}

// isLoadingPath reports whether path starts with any entry of loadingPaths.
func isLoadingPath(path string) bool {
	matched := false
	for i := 0; i < len(loadingPaths) && !matched; i++ {
		matched = strings.HasPrefix(path, loadingPaths[i])
	}
	return matched
}
|
||||||
|
|
||||||
|
type loadingWriter struct {
|
||||||
|
hasWritten bool
|
||||||
|
writer http.ResponseWriter
|
||||||
|
req *http.Request
|
||||||
|
ctx context.Context
|
||||||
|
logger *logmon.Monitor
|
||||||
|
modelName string
|
||||||
|
startTime time.Time
|
||||||
|
|
||||||
|
pendingMu sync.Mutex
|
||||||
|
pendingUpdate string
|
||||||
|
|
||||||
|
// writeMu serializes writes to the underlying writer and guards released.
|
||||||
|
// Once released is set, the streaming goroutine must not touch the writer
|
||||||
|
// again — ServeHTTP has reclaimed it (to run the real handler or to return)
|
||||||
|
// and writing/flushing a finalized response panics.
|
||||||
|
writeMu sync.Mutex
|
||||||
|
released bool
|
||||||
|
|
||||||
|
// closed by start when the goroutine finishes (after cleanup messages)
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
// test-only: closed when start enters its loop
|
||||||
|
loopStarted chan struct{}
|
||||||
|
// test-only: override the 1s tick interval
|
||||||
|
tickDuration time.Duration
|
||||||
|
// test-only: override character streaming speed (0 = no delay)
|
||||||
|
charPerSecond float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter {
|
||||||
|
s := &loadingWriter{
|
||||||
|
writer: w,
|
||||||
|
req: req,
|
||||||
|
ctx: req.Context(),
|
||||||
|
logger: logger,
|
||||||
|
modelName: modelName,
|
||||||
|
startTime: time.Now(),
|
||||||
|
tickDuration: 750 * time.Millisecond,
|
||||||
|
charPerSecond: 75,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
s.Header().Set("Cache-Control", "no-cache")
|
||||||
|
s.Header().Set("Connection", "keep-alive")
|
||||||
|
s.WriteHeader(http.StatusOK)
|
||||||
|
s.sendLine("━━━━━")
|
||||||
|
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName))
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) setUpdate(msg string) {
|
||||||
|
s.pendingMu.Lock()
|
||||||
|
s.pendingUpdate = msg
|
||||||
|
s.pendingMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) start(ctx context.Context) {
|
||||||
|
s.done = make(chan struct{})
|
||||||
|
defer close(s.done)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Skip cleanup writes if the client disconnected — the connection
|
||||||
|
// is being torn down and flushing against it will panic.
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
duration := time.Since(s.startTime)
|
||||||
|
s.sendData("\n")
|
||||||
|
s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds()))
|
||||||
|
s.sendLine("━━━━━")
|
||||||
|
s.sendLine(" ")
|
||||||
|
}()
|
||||||
|
|
||||||
|
remarks := make([]string, len(loadingRemarks))
|
||||||
|
copy(remarks, loadingRemarks)
|
||||||
|
rand.Shuffle(len(remarks), func(i, j int) {
|
||||||
|
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||||
|
})
|
||||||
|
ri := 0
|
||||||
|
|
||||||
|
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||||
|
lastRemarkTime := time.Time{}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(s.tickDuration)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
if s.loopStarted != nil {
|
||||||
|
close(s.loopStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.pendingMu.Lock()
|
||||||
|
update := s.pendingUpdate
|
||||||
|
s.pendingUpdate = ""
|
||||||
|
s.pendingMu.Unlock()
|
||||||
|
|
||||||
|
if update != "" {
|
||||||
|
s.sendData("\n")
|
||||||
|
s.sendInline(update)
|
||||||
|
s.sendData(" ")
|
||||||
|
lastRemarkTime = time.Now()
|
||||||
|
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||||
|
} else if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||||
|
remark := remarks[ri%len(remarks)]
|
||||||
|
ri++
|
||||||
|
s.sendData("\n")
|
||||||
|
s.sendInline(remark)
|
||||||
|
s.sendData(" ")
|
||||||
|
lastRemarkTime = time.Now()
|
||||||
|
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||||
|
} else {
|
||||||
|
s.sendData(".")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool {
|
||||||
|
if s.done == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
return true
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) sendInline(text string) {
|
||||||
|
chunkSize := 10
|
||||||
|
if s.charPerSecond > 0 {
|
||||||
|
chunkSize = max(3, int(s.charPerSecond)/15)
|
||||||
|
}
|
||||||
|
|
||||||
|
runes := []rune(text)
|
||||||
|
for i := 0; i < len(runes); {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
end := i + chunkSize
|
||||||
|
if end > len(runes) {
|
||||||
|
end = len(runes)
|
||||||
|
}
|
||||||
|
chunk := string(runes[i:end])
|
||||||
|
s.sendData(chunk)
|
||||||
|
i = end
|
||||||
|
|
||||||
|
if i < len(runes) && s.charPerSecond > 0 {
|
||||||
|
time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) sendLine(line string) {
|
||||||
|
if line == "" {
|
||||||
|
s.sendData("\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.sendInline(line)
|
||||||
|
s.sendData("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) sendData(data string) {
|
||||||
|
type Delta struct {
|
||||||
|
ReasoningContent string `json:"reasoning_content"`
|
||||||
|
}
|
||||||
|
type Choice struct {
|
||||||
|
Delta Delta `json:"delta"`
|
||||||
|
}
|
||||||
|
type SSEMessage struct {
|
||||||
|
Choices []Choice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := SSEMessage{
|
||||||
|
Choices: []Choice{
|
||||||
|
{
|
||||||
|
Delta: Delta{
|
||||||
|
ReasoningContent: data,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeMu.Lock()
|
||||||
|
defer s.writeMu.Unlock()
|
||||||
|
// Once ServeHTTP has reclaimed the writer (release), writing/flushing it
|
||||||
|
// races the real handler or panics on a finalized response. Stop here.
|
||||||
|
if s.released {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData); err != nil {
|
||||||
|
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// release fences the loadingWriter off from the underlying ResponseWriter.
|
||||||
|
// After it returns, the streaming goroutine will not write to or flush the
|
||||||
|
// writer again: any in-flight write completes under writeMu first, and later
|
||||||
|
// writes short-circuit on released. The caller can then safely hand the writer
|
||||||
|
// to the real handler or let ServeHTTP return without racing a finalized
|
||||||
|
// response (a use-after-return Flush panics on the recycled *bufio.Writer).
|
||||||
|
func (s *loadingWriter) release() {
|
||||||
|
s.writeMu.Lock()
|
||||||
|
s.released = true
|
||||||
|
s.writeMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) Header() http.Header {
|
||||||
|
return s.writer.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) Write(data []byte) (int, error) {
|
||||||
|
return s.writer.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) WriteHeader(statusCode int) {
|
||||||
|
if s.hasWritten {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.hasWritten = true
|
||||||
|
s.writer.WriteHeader(statusCode)
|
||||||
|
s.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *loadingWriter) Flush() {
|
||||||
|
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
// loadingRemarks is the pool of humorous one-liners streamed to the client
// while a model loads. The streaming loop works on a shuffled copy, so the
// order here carries no meaning.
var loadingRemarks = []string{
	"Still faster than your last standup meeting",
	"Reticulating splines",
	"Waking up the hamsters",
	"Teaching the model manners",
	"Convincing the GPU to participate",
	"Loading weights (they're heavy)",
	"Please enjoy this elevator music in your head",
	"Pretending to be productive",
	"Reading the entire internet, page by page",
	"Staring at the abyss, the abyss is buffering",
	"Applying layer after layer of disembodied cognition",
	"Remembering everything it forgot during quantization",
	"Counting to 405 billion, one parameter at a time",
	"Summoning the stochastic parroting",
	"Hold on, the GPU is questioning its existence",
	"Deciding which facts to hallucinate today",
	"Untangling the transformer spaghetti",
	"Warming up the token soup",
	"Your prompt is in a queue, behind 7 billion other thoughts",
	"Running `sudo apt-get install intelligence`",
	"Defragmenting the latent space",
	"Polishing each matrix multiplication by hand",
	"Whispering sweet nothings to the attention heads",
	"Aligning with human values, one reluctant epoch at a time",
	"The model is thinking about what it's about to think about",
	"Loading... and by loading we mean making you wait",
	"Spinning up the cloud GPU, please be patient while we burn your credits",
	"Applying duct tape to the context window",
	"Bribing the GPU scheduler for a timeslice",
	"Would you like to hear a fun fact while we load? Too bad.",
	"Hot swapping your sanity for an LLM",
	"Compressing optimism into FP16",
	"Ignoring 90% of the attention to save you 50% of the time",
	"Counting the exact same thing three times just to be sure",
	"Sorry, the inference you have reached is not in service",
	"Rotating the positional encodings counterclockwise for good luck",
	"Your call is very important to us. Please continue to hold.",
	"Unpacking the blobs. All 300GB of them.",
	"Initializing the thing that initializes the other thing",
	"Converting electricity into existential dread",
	"Flattening the curve... wait, the tensor. Flattening the tensor.",
	"Fetching the fetch of a fetch, callback hell edition",
	"The GPU is at 100%. The fan is now a helicopter.",
	"Baking the weights at 350° for a golden-brown inference",
	"Recalibrating the confidence of things it's still wrong about",
	"Have you tried turning it off and on again? No? Good, wait here.",
	"Simulating deep thought by pausing dramatically",
	"Loading the model that knows more than you but still can't count r's in 'strawberry'",
	"Convincing CUDA to cooperate. This may take a while.",
	"VRAM: 23.9GB used of 24GB. Living on the edge.",
	"Processing your request with the urgency of a DMV employee",
	"This model was trained on the entire internet, including that embarrassing blog you wrote in 2008",
	"Dispatching tokens through a series of increasingly confused matrix multiplies",
	"Gently lowering your expectations",
	"Applying softmax to our feelings about this load time",
	"Autoregressively generating disappointment, one token at a time",
	"The magic is happening. Somewhere. Probably.",
	"Synchronizing the parallel processes that run in parallel but really don't",
	"Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.",
	"Loading... just like it said 30 seconds ago. And will say 30 seconds from now.",
	"Pre-warming the cache so the first query is only slightly slower than the rest",
	"Have you considered that maybe your question wasn't worth all this compute?",
	"Downloading more RAM (no, really, we're mmap-ing the weights)",
	"Translating your prompt into math it barely understands",
	"Estimating your time remaining with 0% accuracy",
	"Buffering enthusiasm",
	"Model is loading. Go make some coffee. Or a three-course meal.",
	"Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'",
	"Polling for readiness in a loop that would make your CS professor weep",
	"Performing percussive maintenance on the attention mechanism",
	"This loading screen is singlehandedly reversing climate progress",
	"Decompressing the hopes and dreams of thousands of underpaid labelers",
	"Filling the key-value cache with the ghost of prompts past",
	"Currently at step 3 of 9,742 of loading. We'll get there. Eventually.",
	"If you stare at the spinner, it spins slower. It's science.",
	// Spelling corrected from "matricies".
	"Multiplying matrices with the enthusiasm of a teenager doing chores",
	"Applying `torch.nap()` until the model feels refreshed",
	"Reacquainting the model with the concept of 'facts' it forgot during fine-tuning",
	"Sorry for the wait. No, wait, we're not actually sorry.",
	"Your GPU is now a space heater with a side hustle in linear algebra",
	"Allocating memory like a billionaire allocates tax avoidance strategies",
	"The model saw \"As an AI language model\" and won't stop saying it now",
	"Installing dependencies you didn't know existed and will never use again",
	"Re-reading 'Attention Is All You Need' for the 400th time",
	"Convincing the embedding layer that context is overrated",
	"Manually untangling the residual connections with a tiny comb",
	"On hold with the cloud provider trying to explain why 8 H100s isn't enough",
	"Adjusting temperatures: model is 0.7, server room is 104°F",
	"Please hold while we justify this electricity bill to accounting",
	"Stacking decoder blocks like a Jenga tower at a LAN party",
	"Compensating for your lack of patience with our lack of speed",
	"This is a loading screen comment. Loading screens have comments now. Welcome to the future.",
	"Processing the entire works of Shakespeare backwards just in case",
	"The model is loading slower than your last `npm install`",
	"Rehearsing plausible-sounding explanations for why it got everything wrong",
	"Populating the context with filler while you wait for actual content",
	"Optimizing for BLEU score, which definitely correlates with making you laugh",
	"Generating an embedding for each and every letter of the alphabet, individually",
	"Coming soon: llama-swap v2 with actual performance improvements. Probably.",
	"Loading a model larger than your attention span",
	"Performing a seance to invoke the spirit of Geoff Hinton",
	"Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.",
	"Converting all the internet's bad opinions into a surprisingly useful autocomplete",
	"Laying down each layer with the care of a Michelin-starred pastry chef",
	"Checking if the model still thinks birds are government drones. Yep.",
	"Activating the neurons responsible for 'I cannot assist with that request'",
	"This model was trained on the same internet that brought you Rickrolling. You're welcome.",
	"Realigning the alignment so it aligns with the previous alignment",
	"Running `nvidia-smi` and sighing heavily",
	"If you close your eyes, the loading bar moves faster. Proven by science.",
	"EULA said 'by using this software you agree to wait forever' and you clicked Accept",
	"Zipping the GPUs to make them go faster",
	"Padding the context window with existential padding",
	"We could have used a smaller model but someone wanted 'quality'",
	"Disentangling the latent space into something resembling coherence",
	"Slow is smooth, smooth is fast, but this is just slow",
	"Memory-mapping like it's a AAA title from 2012",
	"Your patience has been tokenized and added to the training set. Thank you for your contribution.",
	"Loading is CPU-bound and your CPU is busy regretting its life choices",
	"Exploring the high-dimensional manifold of ways to say 'just a moment'",
	"The model is experiencing a brief but intense moment of imposter syndrome",
	"Initializing 7B parameters by rolling 7B 16-sided dice",
	"Panic! at the disk I/O",
	"Intelligence is loading... your definition of intelligence may vary",
	"This model was distilled. Unlike your patience, which is evaporating.",
	"Unzipping the model. It's a .gguf file, not a metaphor.",
	"Running inference on the concept of 'soon' to estimate remaining time",
	"Loading with all the speed of a government-funded IT project",
	"A blank terminal is a terrible thing to waste. Here's a loading message instead.",
}
|
||||||
@@ -0,0 +1,328 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
|
||||||
|
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
|
||||||
|
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
|
||||||
|
}
|
||||||
|
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
|
||||||
|
t.Errorf("Cache-Control: want no-cache, got %q", cc)
|
||||||
|
}
|
||||||
|
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
|
||||||
|
t.Errorf("Connection: want keep-alive, got %q", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.HasPrefix(body, "data: ") {
|
||||||
|
t.Errorf("expected SSE data: prefix, got: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := extractStreamedContent(body)
|
||||||
|
if !strings.Contains(content, "━━━━━\n") {
|
||||||
|
t.Errorf("missing separator in streamed content: %q", content)
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
|
||||||
|
t.Errorf("missing initial message in streamed content: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.WriteHeader(http.StatusCreated)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_WritePassthrough(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.Write([]byte("hello"))
|
||||||
|
lw.Flush()
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, "hello") {
|
||||||
|
t.Errorf("Write passthrough failed, body: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.tickDuration = 10 * time.Millisecond
|
||||||
|
lw.loopStarted = make(chan struct{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go lw.start(ctx)
|
||||||
|
<-lw.loopStarted
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if !lw.waitForCompletion(time.Second) {
|
||||||
|
t.Fatal("waitForCompletion timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, "Done!") {
|
||||||
|
t.Errorf("expected Done! message, body: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.tickDuration = 10 * time.Millisecond
|
||||||
|
lw.charPerSecond = 0
|
||||||
|
lw.loopStarted = make(chan struct{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go lw.start(ctx)
|
||||||
|
<-lw.loopStarted
|
||||||
|
|
||||||
|
lw.setUpdate("custom status message")
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if !lw.waitForCompletion(time.Second) {
|
||||||
|
t.Fatal("waitForCompletion timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
content := extractStreamedContent(body)
|
||||||
|
if !strings.Contains(content, "custom status message") {
|
||||||
|
t.Errorf("expected setUpdate message in output, got: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_SendDataFormat(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.sendData("hello world")
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
|
||||||
|
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(body, "data: ") {
|
||||||
|
t.Errorf("expected data: prefix, got: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_SendLine(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.charPerSecond = 0
|
||||||
|
|
||||||
|
// Capture only the content from this sendLine call
|
||||||
|
before := w.Body.Len()
|
||||||
|
lw.sendLine("line content")
|
||||||
|
after := w.Body.Len()
|
||||||
|
chunkBody := w.Body.String()[before:after]
|
||||||
|
|
||||||
|
content := extractStreamedContent(chunkBody)
|
||||||
|
if content != "line content\n" {
|
||||||
|
t.Errorf("expected complete streamed line, got: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
lw.tickDuration = 10 * time.Millisecond
|
||||||
|
lw.charPerSecond = 0
|
||||||
|
lw.loopStarted = make(chan struct{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
lw.start(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-lw.loopStarted
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
lines := countSSEMessages(body)
|
||||||
|
if lines < 2 {
|
||||||
|
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadingWriter_ReqStored(t *testing.T) {
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||||
|
if lw.req != req {
|
||||||
|
t.Fatal("req not stored")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsLoadingPath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"/v1/chat/completions", true},
|
||||||
|
{"/v1/chat/completions/extra", true},
|
||||||
|
{"/v1/completions", false},
|
||||||
|
{"/v1/embeddings", false},
|
||||||
|
{"/health", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.path, func(t *testing.T) {
|
||||||
|
if got := isLoadingPath(tt.path); got != tt.want {
|
||||||
|
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
wantStreaming bool
|
||||||
|
}{
|
||||||
|
{"streaming true", "model=llama3&stream=true", true},
|
||||||
|
{"streaming false", "model=llama3&stream=false", false},
|
||||||
|
{"no stream param", "model=llama3", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.Streaming != tt.wantStreaming {
|
||||||
|
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantStreaming bool
|
||||||
|
}{
|
||||||
|
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||||
|
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||||
|
{"no stream param", `{"model":"llama3"}`, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.Streaming != tt.wantStreaming {
|
||||||
|
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if !got.Streaming {
|
||||||
|
t.Error("Streaming should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// countSSEMessages returns the number of lines in s that start with the SSE
// "data: " prefix.
func countSSEMessages(s string) int {
	scanner := bufio.NewScanner(strings.NewReader(s))
	// Raise the token limit to cover the whole input; with the default
	// 64 KiB limit a single long line makes Scan stop early and the count
	// silently comes up short.
	scanner.Buffer(nil, len(s)+1)

	count := 0
	for scanner.Scan() {
		if strings.HasPrefix(scanner.Text(), "data: ") {
			count++
		}
	}
	return count
}
|
||||||
|
|
||||||
|
// extractStreamedContent concatenates choices[0].delta.reasoning_content from
// every "data: " line in body that holds a JSON object of that shape. Lines
// without the prefix, payloads that fail to decode (for example a non-JSON
// terminator), and messages with no choices are skipped.
func extractStreamedContent(body string) string {
	var result strings.Builder

	scanner := bufio.NewScanner(strings.NewReader(body))
	// Raise the token limit to cover the whole input; with the default
	// 64 KiB limit a single long line makes Scan stop early and the rest of
	// the stream is silently dropped.
	scanner.Buffer(nil, len(body)+1)

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")

		// Declared per iteration so no state carries over between messages.
		var msg struct {
			Choices []struct {
				Delta struct {
					ReasoningContent string `json:"reasoning_content"`
				} `json:"delta"`
			} `json:"choices"`
		}
		if err := json.Unmarshal([]byte(payload), &msg); err != nil {
			// Deliberately best-effort: undecodable payloads are ignored.
			continue
		}
		if len(msg.Choices) > 0 {
			result.WriteString(msg.Choices[0].Delta.ReasoningContent)
		}
	}
	return result.String()
}
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Matrix struct {
|
||||||
|
*baseRouter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
||||||
|
if conf.Matrix == nil {
|
||||||
|
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
planner := &matrixPlanner{
|
||||||
|
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
|
||||||
|
logger: proxylog,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a process for every model in the config. Any model can run alone
|
||||||
|
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
||||||
|
processes := make(map[string]process.Process, len(conf.Models))
|
||||||
|
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
|
||||||
|
planner.processes = processes
|
||||||
|
|
||||||
|
for mid, modelCfg := range conf.Models {
|
||||||
|
procLog := logmon.NewWriter(upstreamlog)
|
||||||
|
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
|
||||||
|
if err != nil {
|
||||||
|
base.shutdownFn()
|
||||||
|
base.procCancel()
|
||||||
|
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||||
|
}
|
||||||
|
processes[mid] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &Matrix{baseRouter: base}
|
||||||
|
go base.run()
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matrixPlanner decides evictions by asking the matrix solver against the
|
||||||
|
// current running set.
|
||||||
|
type matrixPlanner struct {
|
||||||
|
solver *matrixSolver
|
||||||
|
processes map[string]process.Process
|
||||||
|
logger *logmon.Monitor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||||
|
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *matrixPlanner) OnSwapStart(target string) {
|
||||||
|
running := p.runningModels()
|
||||||
|
result := p.solver.Solve(target, running)
|
||||||
|
switch {
|
||||||
|
case len(result.Evict) > 0:
|
||||||
|
p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||||
|
target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||||
|
case len(running) == 0:
|
||||||
|
p.logger.Infof("matrix: model=%s starting (no models running)", target)
|
||||||
|
default:
|
||||||
|
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *matrixPlanner) runningModels() []string {
|
||||||
|
return p.runningSet(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runningSet returns the union of live processes (State != Stopped/Shutdown)
|
||||||
|
// and any extra IDs the baseRouter has already committed to loading but which
|
||||||
|
// the process state machine has not yet reflected.
|
||||||
|
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
|
||||||
|
seen := make(map[string]struct{}, len(p.processes))
|
||||||
|
var running []string
|
||||||
|
for id, proc := range p.processes {
|
||||||
|
st := proc.State()
|
||||||
|
if st == process.StateStopped || st == process.StateShutdown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
running = append(running, id)
|
||||||
|
}
|
||||||
|
for _, id := range alsoRunning {
|
||||||
|
if _, dup := seen[id]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
running = append(running, id)
|
||||||
|
}
|
||||||
|
sort.Strings(running)
|
||||||
|
return running
|
||||||
|
}
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// matrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||||
|
// It is safe for concurrent reads after construction.
|
||||||
|
type matrixSolver struct {
|
||||||
|
expandedSets []config.ExpandedSet // all valid model combinations
|
||||||
|
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||||
|
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver {
|
||||||
|
modelToSets := make(map[string][]int)
|
||||||
|
for i, es := range expandedSets {
|
||||||
|
for _, model := range es.Models {
|
||||||
|
modelToSets[model] = append(modelToSets[model], i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &matrixSolver{
|
||||||
|
expandedSets: expandedSets,
|
||||||
|
evictCosts: evictCosts,
|
||||||
|
modelToSets: modelToSets,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// solveResult is the outcome of one matrixSolver.Solve call.
type solveResult struct {
	// Evict lists the running models that must be stopped.
	Evict []string
	// TargetSet is the chosen set of models (informational).
	TargetSet []string
	// SetName is the name of the chosen set.
	SetName string
	// DSL is the original DSL expression of the chosen set.
	DSL string
	// TotalCost is the summed eviction cost of Evict.
	TotalCost int
}
|
||||||
|
|
||||||
|
// Solve determines which models to evict when a model is requested.
|
||||||
|
//
|
||||||
|
// Algorithm:
|
||||||
|
// 1. If requestedModel is already running, no eviction needed.
|
||||||
|
// 2. Find all sets containing requestedModel.
|
||||||
|
// 3. If no sets found, the model runs alone; evict all running models.
|
||||||
|
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||||
|
// models NOT in that set.
|
||||||
|
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||||
|
// 6. Return models to evict and the chosen set.
|
||||||
|
func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult {
|
||||||
|
if slices.Contains(runningModels, requestedModel) {
|
||||||
|
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||||
|
return solveResult{
|
||||||
|
TargetSet: runningModels,
|
||||||
|
SetName: setName,
|
||||||
|
DSL: dsl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidateIndices := s.modelToSets[requestedModel]
|
||||||
|
|
||||||
|
// Model not in any set: runs alone, evict everything.
|
||||||
|
if len(candidateIndices) == 0 {
|
||||||
|
evict := make([]string, len(runningModels))
|
||||||
|
copy(evict, runningModels)
|
||||||
|
return solveResult{
|
||||||
|
Evict: evict,
|
||||||
|
TargetSet: []string{requestedModel},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bestCost := -1
|
||||||
|
bestIdx := -1
|
||||||
|
|
||||||
|
for _, idx := range candidateIndices {
|
||||||
|
setModels := s.expandedSets[idx].Models
|
||||||
|
cost := 0
|
||||||
|
for _, running := range runningModels {
|
||||||
|
if !slices.Contains(setModels, running) {
|
||||||
|
cost += s.evictCost(running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||||
|
bestCost = cost
|
||||||
|
bestIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chosen := s.expandedSets[bestIdx]
|
||||||
|
var evict []string
|
||||||
|
for _, running := range runningModels {
|
||||||
|
if !slices.Contains(chosen.Models, running) {
|
||||||
|
evict = append(evict, running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return solveResult{
|
||||||
|
Evict: evict,
|
||||||
|
TargetSet: chosen.Models,
|
||||||
|
SetName: chosen.SetName,
|
||||||
|
DSL: chosen.DSL,
|
||||||
|
TotalCost: bestCost,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingSet finds the expanded set that contains all running models.
|
||||||
|
// Returns the set name and DSL, or empty strings if no match.
|
||||||
|
func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||||
|
for _, idx := range s.modelToSets[requestedModel] {
|
||||||
|
set := s.expandedSets[idx]
|
||||||
|
allInSet := true
|
||||||
|
for _, m := range runningModels {
|
||||||
|
if !slices.Contains(set.Models, m) {
|
||||||
|
allInSet = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allInSet {
|
||||||
|
return set.SetName, set.DSL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *matrixSolver) evictCost(model string) int {
|
||||||
|
if cost, ok := s.evictCosts[model]; ok {
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
@@ -0,0 +1,244 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestMatrix builds a Matrix router from supplied processes, bypassing
|
||||||
|
// NewMatrix's call to process.New.
|
||||||
|
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
||||||
|
t.Helper()
|
||||||
|
logger := logmon.NewWriter(io.Discard)
|
||||||
|
planner := &matrixPlanner{
|
||||||
|
solver: newMatrixSolver(expanded, evictCosts),
|
||||||
|
processes: processes,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
base := newBaseRouter("matrix", conf, processes, planner, logger)
|
||||||
|
base.testProcessed = make(chan struct{}, 64)
|
||||||
|
r := &Matrix{baseRouter: base}
|
||||||
|
go base.run()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if !r.shuttingDown.Load() {
|
||||||
|
_ = r.Shutdown(time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseMatrixConfig() config.Config {
|
||||||
|
return config.Config{
|
||||||
|
HealthCheckTimeout: 5,
|
||||||
|
Matrix: &config.MatrixConfig{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers
|
||||||
|
// eviction of running models that are not in any shared set with it.
|
||||||
|
func TestMatrix_SwapEvictsConflicting(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0) // park a Run goroutine so Stop has something to release
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
// Two single-model sets: a and b never coexist, so loading b must evict a.
|
||||||
|
expanded := []config.ExpandedSet{
|
||||||
|
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
|
||||||
|
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
|
||||||
|
}
|
||||||
|
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := b.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrix_CoexistInSet verifies that a model is not evicted when the target
|
||||||
|
// shares a set with it (the fast path applies if the target is already ready).
|
||||||
|
func TestMatrix_CoexistInSet(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
a.markReady()
|
||||||
|
go a.Run(0)
|
||||||
|
|
||||||
|
b := newFakeProcess("b")
|
||||||
|
b.autoReady = true
|
||||||
|
|
||||||
|
// Both fit in s_ab, so b's swap should not stop a.
|
||||||
|
expanded := []config.ExpandedSet{
|
||||||
|
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
|
||||||
|
}
|
||||||
|
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, newRequest("b"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
|
||||||
|
}
|
||||||
|
if got := b.runCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("b.runCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrix_CoexistingSetParallel verifies that two models that share an
|
||||||
|
// expanded set load in parallel — the solver returns empty Evict for both,
|
||||||
|
// the collision predicate clears them, and both swaps run together.
|
||||||
|
func TestMatrix_CoexistingSetParallel(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
|
||||||
|
expanded := []config.ExpandedSet{
|
||||||
|
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
|
||||||
|
}
|
||||||
|
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
|
||||||
|
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, r.testProcessed, 1)
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, r.testProcessed, 1)
|
||||||
|
|
||||||
|
<-a.runStarted
|
||||||
|
<-pb.runStarted
|
||||||
|
|
||||||
|
a.markReady()
|
||||||
|
pb.markReady()
|
||||||
|
|
||||||
|
for i, ch := range []chan struct{}{done1, done2} {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("request %d did not complete", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
|
||||||
|
}
|
||||||
|
if got := pb.stopCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
||||||
|
// that cannot coexist with the in-flight first model queues until the first
|
||||||
|
// completes, and then evicts it. This exercises the alsoRunning hint via the
|
||||||
|
// matrix solver's union into runningSet.
|
||||||
|
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||||
|
a := newFakeProcess("a")
|
||||||
|
pb := newFakeProcess("b")
|
||||||
|
|
||||||
|
expanded := []config.ExpandedSet{
|
||||||
|
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
|
||||||
|
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
|
||||||
|
}
|
||||||
|
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
|
||||||
|
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
done1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.ServeHTTP(w1, newRequest("a"))
|
||||||
|
close(done1)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, r.testProcessed, 1)
|
||||||
|
|
||||||
|
// B arrives before A transitions to StateStarting. The solver sees A via
|
||||||
|
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
done2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.ServeHTTP(w2, newRequest("b"))
|
||||||
|
close(done2)
|
||||||
|
}()
|
||||||
|
waitProcessed(t, r.testProcessed, 1)
|
||||||
|
|
||||||
|
if got := pb.runCalls.Load(); got != 0 {
|
||||||
|
t.Errorf("b started in parallel: runCalls=%d want 0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-a.runStarted
|
||||||
|
a.markReady()
|
||||||
|
waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
|
||||||
|
<-pb.runStarted
|
||||||
|
pb.markReady()
|
||||||
|
|
||||||
|
for i, ch := range []chan struct{}{done1, done2} {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("request %d did not complete", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := a.stopCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrixSolver_TieBreakDefinitionOrder pins the solver's tie-break rule:
|
||||||
|
// when multiple candidate sets have equal eviction cost, the earlier-defined
|
||||||
|
// set wins.
|
||||||
|
func TestMatrixSolver_TieBreakDefinitionOrder(t *testing.T) {
|
||||||
|
expanded := []config.ExpandedSet{
|
||||||
|
{SetName: "first", DSL: "a & b", Models: []string{"a", "b"}},
|
||||||
|
{SetName: "second", DSL: "a & c", Models: []string{"a", "c"}},
|
||||||
|
}
|
||||||
|
s := newMatrixSolver(expanded, nil)
|
||||||
|
|
||||||
|
// No models running, request "a": both sets have cost 0 and contain a.
|
||||||
|
// Definition order: "first" wins.
|
||||||
|
result := s.Solve("a", nil)
|
||||||
|
if result.SetName != "first" {
|
||||||
|
t.Errorf("SetName=%q want %q", result.SetName, "first")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrixSolver_EvictCostsPreferred verifies that higher evict costs steer
|
||||||
|
// the solver toward a cheaper set.
|
||||||
|
func TestMatrixSolver_EvictCostsPreferred(t *testing.T) {
|
||||||
|
// b is expensive to evict; c is cheap. Request "a" with both b and c
|
||||||
|
// running. The solver should pick the set that keeps b.
|
||||||
|
expanded := []config.ExpandedSet{
|
||||||
|
{SetName: "a_with_c", DSL: "a & c", Models: []string{"a", "c"}}, // would evict b (cost 10)
|
||||||
|
{SetName: "a_with_b", DSL: "a & b", Models: []string{"a", "b"}}, // would evict c (cost 1)
|
||||||
|
}
|
||||||
|
s := newMatrixSolver(expanded, map[string]int{"b": 10, "c": 1})
|
||||||
|
|
||||||
|
result := s.Solve("a", []string{"b", "c"})
|
||||||
|
if result.SetName != "a_with_b" {
|
||||||
|
t.Errorf("SetName=%q want %q (keep expensive b)", result.SetName, "a_with_b")
|
||||||
|
}
|
||||||
|
if len(result.Evict) != 1 || result.Evict[0] != "c" {
|
||||||
|
t.Errorf("Evict=%v want [c]", result.Evict)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,187 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
// peerMember bundles what is needed to forward a request to one configured
// peer.
type peerMember struct {
	// peerID is the peer's key in the peers configuration.
	peerID string
	// reverseProxy forwards requests to the peer's proxy URL.
	reverseProxy *httputil.ReverseProxy
	// apiKey, when non-empty, is sent to the peer in the Authorization and
	// x-api-key headers.
	apiKey string
}
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
cfg config.Config
|
||||||
|
logger *logmon.Monitor
|
||||||
|
peers map[string]*peerMember
|
||||||
|
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownFn context.CancelFunc
|
||||||
|
shuttingDown atomic.Bool
|
||||||
|
inflight sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
|
||||||
|
peers := cfg.Peers
|
||||||
|
modelMap := make(map[string]*peerMember)
|
||||||
|
|
||||||
|
peerIDs := make([]string, 0, len(peers))
|
||||||
|
for peerID := range peers {
|
||||||
|
peerIDs = append(peerIDs, peerID)
|
||||||
|
}
|
||||||
|
sort.Strings(peerIDs)
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
peer := peers[peerID]
|
||||||
|
|
||||||
|
peerTransport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||||
|
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||||
|
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||||
|
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy := &httputil.ReverseProxy{
|
||||||
|
Transport: peerTransport,
|
||||||
|
Rewrite: func(r *httputil.ProxyRequest) {
|
||||||
|
r.SetURL(peer.ProxyURL)
|
||||||
|
r.Out.Host = r.Out.URL.Host
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
logger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||||
|
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||||
|
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||||
|
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||||
|
}
|
||||||
|
http.Error(w, errMsg, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
pp := &peerMember{
|
||||||
|
peerID: peerID,
|
||||||
|
reverseProxy: reverseProxy,
|
||||||
|
apiKey: peer.ApiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
if _, found := modelMap[modelID]; found {
|
||||||
|
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modelMap[modelID] = pp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
return &Peer{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
peers: modelMap,
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownFn: shutdownFn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Peer) Handles(model string) bool {
|
||||||
|
_, ok := r.peers[model]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Peer) Shutdown(timeout time.Duration) error {
|
||||||
|
if !r.shuttingDown.CompareAndSwap(false, true) {
|
||||||
|
return fmt.Errorf("shutdown already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeout == 0 {
|
||||||
|
r.shutdownFn()
|
||||||
|
r.inflight.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.inflight.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
case <-time.After(timeout):
|
||||||
|
r.shutdownFn()
|
||||||
|
r.inflight.Wait()
|
||||||
|
return fmt.Errorf("peer shutdown timed out after %v", timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if r.shuttingDown.Load() {
|
||||||
|
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.inflight.Add(1)
|
||||||
|
defer r.inflight.Done()
|
||||||
|
|
||||||
|
data, err := FetchContext(req, r.cfg)
|
||||||
|
if err != nil {
|
||||||
|
SendError(w, req, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pp, found := r.peers[data.ModelID]
|
||||||
|
if !found {
|
||||||
|
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||||
|
SendError(w, req, ErrNoPeerModelFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
|
||||||
|
|
||||||
|
if pp.apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||||
|
req.Header.Set("x-api-key", pp.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel the proxy request when the client disconnects or shutdown times out.
|
||||||
|
// AfterFunc links both parent contexts to our child without a goroutine leak.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
stopReq := context.AfterFunc(req.Context(), cancel)
|
||||||
|
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
pp.reverseProxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
stopShutdown()
|
||||||
|
stopReq()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
@@ -0,0 +1,611 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testLogger is the logger shared by the tests in this file; it is built
// over os.Stdout.
var testLogger = logmon.NewWriter(os.Stdout)

// init sets the shared test logger's level to logmon.LevelWarn.
func init() {
	testLogger.SetLogLevel(logmon.LevelWarn)
}
|
||||||
|
|
||||||
|
func TestNewPeer_EmptyPeers(t *testing.T) {
|
||||||
|
pr, err := NewPeer(config.Config{}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if pr == nil {
|
||||||
|
t.Fatal("expected non-nil Peer")
|
||||||
|
}
|
||||||
|
if len(pr.peers) != 0 {
|
||||||
|
t.Fatalf("expected empty peers map, got %d entries", len(pr.peers))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeer_SinglePeer(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "test-key",
|
||||||
|
Models: []string{"model-a", "model-b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(pr.peers) != 2 {
|
||||||
|
t.Fatalf("expected 2 entries, got %d", len(pr.peers))
|
||||||
|
}
|
||||||
|
if _, ok := pr.peers["model-a"]; !ok {
|
||||||
|
t.Error("expected model-a to be mapped")
|
||||||
|
}
|
||||||
|
if _, ok := pr.peers["model-b"]; !ok {
|
||||||
|
t.Error("expected model-b to be mapped")
|
||||||
|
}
|
||||||
|
if _, ok := pr.peers["model-c"]; ok {
|
||||||
|
t.Error("expected model-c to not be mapped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeer_MultiplePeers(t *testing.T) {
|
||||||
|
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL1,
|
||||||
|
Models: []string{"model-a", "model-b"},
|
||||||
|
},
|
||||||
|
"peer2": config.PeerConfig{
|
||||||
|
Proxy: "http://peer2.example.com:8080",
|
||||||
|
ProxyURL: proxyURL2,
|
||||||
|
Models: []string{"model-c", "model-d"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(pr.peers) != 4 {
|
||||||
|
t.Fatalf("expected 4 entries, got %d", len(pr.peers))
|
||||||
|
}
|
||||||
|
for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} {
|
||||||
|
if _, ok := pr.peers[m]; !ok {
|
||||||
|
t.Errorf("expected %s to be mapped", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeer_DuplicateModel(t *testing.T) {
|
||||||
|
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"alpha-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL1,
|
||||||
|
Models: []string{"duplicate-model"},
|
||||||
|
},
|
||||||
|
"beta-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://peer2.example.com:8080",
|
||||||
|
ProxyURL: proxyURL2,
|
||||||
|
Models: []string{"duplicate-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(pr.peers) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers))
|
||||||
|
}
|
||||||
|
if _, ok := pr.peers["duplicate-model"]; !ok {
|
||||||
|
t.Error("expected duplicate-model to be mapped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_Success(t *testing.T) {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("response from peer"))
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if w.Body.String() != "response from peer" {
|
||||||
|
t.Errorf("expected 'response from peer', got %q", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) {
|
||||||
|
pr, err := NewPeer(config.Config{}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
||||||
|
pr, err := NewPeer(config.Config{}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
||||||
|
var receivedAuthHeader string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedAuthHeader = r.Header.Get("Authorization")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "secret-api-key",
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if receivedAuthHeader != "Bearer secret-api-key" {
|
||||||
|
t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
||||||
|
var receivedAuthHeader string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedAuthHeader = r.Header.Get("Authorization")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "",
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if receivedAuthHeader != "" {
|
||||||
|
t.Errorf("expected no auth header, got %q", receivedAuthHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
||||||
|
var receivedHost string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedHost = r.Host
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !strings.HasPrefix(receivedHost, "127.0.0.1:") {
|
||||||
|
t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Header().Get("X-Accel-Buffering") != "no" {
|
||||||
|
t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pr.Shutdown(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "shutting down") {
|
||||||
|
t.Errorf("expected 'shutting down' in body, got %q", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
||||||
|
started := make(chan struct{})
|
||||||
|
released := make(chan struct{})
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(started)
|
||||||
|
<-released
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
|
||||||
|
shutdownDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
shutdownDone <- pr.Shutdown(500 * time.Millisecond)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Shutdown should be waiting on inflight. If it finished already something is wrong.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case err := <-shutdownDone:
|
||||||
|
t.Errorf("shutdown completed before inflight finished: %v", err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
close(released)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-shutdownDone:
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("shutdown errored after inflight completed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("shutdown did not complete after inflight finished")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
||||||
|
started := make(chan struct{})
|
||||||
|
released := make(chan struct{})
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(started)
|
||||||
|
<-released
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
|
||||||
|
err = pr.Shutdown(100 * time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected timeout error from shutdown")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(released)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ShutdownMultiple(t *testing.T) {
|
||||||
|
pr, err := NewPeer(config.Config{}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pr.Shutdown(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pr.Shutdown(0)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error on second shutdown")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "already in progress") {
|
||||||
|
t.Errorf("expected 'already in progress', got %q", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"extracted-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"context-model"},
|
||||||
|
},
|
||||||
|
"peer2": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"body-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
pr.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeer_CustomTimeouts(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"test-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"model1"},
|
||||||
|
Timeouts: config.TimeoutsConfig{
|
||||||
|
Connect: 45,
|
||||||
|
ResponseHeader: 300,
|
||||||
|
TLSHandshake: 15,
|
||||||
|
ExpectContinue: 2,
|
||||||
|
IdleConn: 120,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
member, ok := pr.peers["model1"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected model1 to be mapped")
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected Transport to be *http.Transport")
|
||||||
|
}
|
||||||
|
|
||||||
|
if transport.ResponseHeaderTimeout != 300*time.Second {
|
||||||
|
t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout)
|
||||||
|
}
|
||||||
|
if transport.TLSHandshakeTimeout != 15*time.Second {
|
||||||
|
t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout)
|
||||||
|
}
|
||||||
|
if transport.ExpectContinueTimeout != 2*time.Second {
|
||||||
|
t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout)
|
||||||
|
}
|
||||||
|
if transport.IdleConnTimeout != 120*time.Second {
|
||||||
|
t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout)
|
||||||
|
}
|
||||||
|
if !transport.ForceAttemptHTTP2 {
|
||||||
|
t.Error("expected ForceAttemptHTTP2 to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"html"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/mostlygeek/llama-swap/internal/config"
	"github.com/mostlygeek/llama-swap/internal/logmon"
	"github.com/mostlygeek/llama-swap/internal/process"
	"github.com/tidwall/gjson"
)
|
||||||
|
|
||||||
|
// contextkey is the unexported type of the package's context key. The name
// field only labels the key.
type contextkey struct{ name string }
|
||||||
|
|
||||||
|
// ReqContextData is the per-request routing information that is stored in,
// and read back from, the request context.
type ReqContextData struct {
	// Model is the model name as given by the request ("model" query
	// parameter, JSON field or form value).
	Model string
	// ModelID is the name used for routing lookups. FetchContext sets it to
	// the config's real model name, falling back to Model.
	ModelID string
	// Streaming is true when the request asked for a streamed response
	// ("stream" set to true).
	Streaming bool
	// SendLoadingState mirrors the model config's SendLoadingState setting;
	// FetchContext fills it in when the model is found in the config.
	SendLoadingState bool
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||||
|
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||||
|
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||||
|
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||||
|
|
||||||
|
ContextKey = &contextkey{"context"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Router serves HTTP requests for a set of models and can be shut down.
type Router interface {
	// Shutdown blocks until the router has shut down, returning nil when
	// the shutdown was successful.
	//
	// timeout controls how long to wait for inflight requests to finish. After
	// the timeout all inflight requests will be cancelled.
	Shutdown(timeout time.Duration) error

	// ServeHTTP implements http.Handler; incoming requests trigger any
	// model swapping and routing logic.
	ServeHTTP(http.ResponseWriter, *http.Request)

	// Handles reports whether this router can serve requests for the given model.
	Handles(model string) bool
}
|
||||||
|
|
||||||
|
// LocalRouter is a Router backed by local processes whose state can be
// inspected and which can be stopped individually. Peer routers, which only
// forward to remote hosts, do not implement it.
type LocalRouter interface {
	Router

	// RunningModels returns the current state of every process that is not
	// stopped or shut down, keyed by model ID.
	RunningModels() map[string]process.ProcessState

	// Unload stops the named models, or every running model when none are
	// named. It blocks until each targeted process has stopped.
	Unload(timeout time.Duration, models ...string)

	// ProcessLogger returns the log monitor for the named model's process.
	// modelID must be a real (non-alias) config key. The boolean result is
	// false when the model is not known to this router.
	ProcessLogger(modelID string) (*logmon.Monitor, bool)
}
|
||||||
|
|
||||||
|
// FetchContext will attempt to get the model id from the context then
|
||||||
|
// from the model body. If it extracts the model from the body it will
|
||||||
|
// store the model in the context for downstream handlers. An error
|
||||||
|
// will be returned when model can not be fetch from either location.
|
||||||
|
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||||
|
data, ok := ReadContext(r.Context())
|
||||||
|
if ok {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, err := ExtractContext(r); err == nil {
|
||||||
|
realName, _ := cfg.RealModelName(data.Model)
|
||||||
|
if realName == "" {
|
||||||
|
realName = data.Model
|
||||||
|
}
|
||||||
|
data.ModelID = realName
|
||||||
|
if mc, ok := cfg.Models[realName]; ok {
|
||||||
|
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||||
|
}
|
||||||
|
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||||
|
return context.WithValue(ctx, ContextKey, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||||
|
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||||
|
return data, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractContext pulls the model name from an HTTP request without consuming the
|
||||||
|
// body. For GET requests it reads the "model" query parameter. For POST
|
||||||
|
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
|
||||||
|
// application/x-www-form-urlencoded bodies. The request body is always restored
|
||||||
|
// before returning so downstream handlers — including reverse proxies that
|
||||||
|
// forward raw bytes upstream — can still read it.
|
||||||
|
func ExtractContext(r *http.Request) (ReqContextData, error) {
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
if model := r.URL.Query().Get("model"); model != "" {
|
||||||
|
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
|
||||||
|
}
|
||||||
|
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
}()
|
||||||
|
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
model := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
|
if model == "" {
|
||||||
|
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
|
||||||
|
}
|
||||||
|
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||||
|
// buffered bytes. The deferred restore above will reset r.Body again
|
||||||
|
// after parsing.
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
if strings.Contains(contentType, "multipart/form-data") {
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if model := r.FormValue("model"); model != "" {
|
||||||
|
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrNoModelInContext):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||||
|
case errors.Is(err, ErrNoPeerModelFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||||
|
case errors.Is(err, ErrNoLocalModelFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||||
|
case errors.Is(err, ErrNoRouterFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||||
|
default:
|
||||||
|
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendResponse writes message with the given status in the format the client
// prefers according to its Accept header: plain text when the header
// contains "text/plain", an HTML page when it contains "text/html", and a
// JSON object ({"src":"llama-swap", "error": "..."}) otherwise.
//
// message is escaped for the chosen format so that quotes, backslashes or
// markup in it (for example from a wrapped error) cannot produce invalid
// JSON or inject HTML. Messages without special characters are written
// exactly as before.
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
	// Check Accept header for preferred response format
	acceptHeader := r.Header.Get("Accept")

	switch {
	case strings.Contains(acceptHeader, "text/plain"):
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(status)
		fmt.Fprintf(w, "llama-swap: %s", message)

	case strings.Contains(acceptHeader, "text/html"):
		w.Header().Set("Content-Type", "text/html")
		w.WriteHeader(status)
		fmt.Fprintf(w, `<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))

	default:
		// json.Marshal yields the message as a quoted, escaped JSON string.
		quoted, err := json.Marshal(message)
		if err != nil {
			// Not expected for a string; keep the response valid JSON.
			quoted = []byte(`""`)
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		fmt.Fprintf(w, `{"src":"llama-swap", "error": %s}`, quoted)
	}
}
|
||||||
@@ -0,0 +1,275 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractContext_GET(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", "model=llama3", "llama3", false},
|
||||||
|
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||||
|
{"model missing", "", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_JSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||||
|
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||||
|
{"model empty string", `{"model":""}`, "", true},
|
||||||
|
{"model key missing", `{"stream":true}`, "", true},
|
||||||
|
{"invalid json", `not-json`, "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
formModel string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", "whisper-1", "whisper-1", false},
|
||||||
|
{"model missing", "", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
form := url.Values{}
|
||||||
|
if tt.formModel != "" {
|
||||||
|
form.Set("model", tt.formModel)
|
||||||
|
}
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_MultipartForm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
formModel string
|
||||||
|
wantModel string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"model present", "whisper-1", "whisper-1", false},
|
||||||
|
{"model missing", "", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
mw := multipart.NewWriter(&buf)
|
||||||
|
if tt.formModel != "" {
|
||||||
|
fw, _ := mw.CreateFormField("model")
|
||||||
|
fw.Write([]byte(tt.formModel))
|
||||||
|
}
|
||||||
|
mw.Close()
|
||||||
|
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||||
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
got, err := ExtractContext(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
if got.Model != tt.wantModel {
|
||||||
|
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||||
|
body := `{"model":"llama3","stream":true}`
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if _, err := ExtractContext(r); err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if string(remaining) != body {
|
||||||
|
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
mw := multipart.NewWriter(&buf)
|
||||||
|
fw, _ := mw.CreateFormField("model")
|
||||||
|
fw.Write([]byte("whisper-1"))
|
||||||
|
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
||||||
|
ff.Write([]byte("fake-audio-bytes"))
|
||||||
|
mw.Close()
|
||||||
|
|
||||||
|
original := buf.Bytes()
|
||||||
|
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||||
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
|
||||||
|
if _, err := ExtractContext(r); err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(remaining, original) {
|
||||||
|
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||||
|
body := "model=whisper-1&extra=value"
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
if _, err := ExtractContext(r); err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if string(remaining) != body {
|
||||||
|
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContext(t *testing.T) {
|
||||||
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
|
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("ContextKey not set or wrong type")
|
||||||
|
}
|
||||||
|
if data.Model != "llama3" {
|
||||||
|
t.Errorf("want %q got %q", "llama3", data.Model)
|
||||||
|
}
|
||||||
|
if data.ModelID != "llama3" {
|
||||||
|
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContext_WithAlias(t *testing.T) {
|
||||||
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||||
|
data, _ := ctx.Value(ContextKey).(ReqContextData)
|
||||||
|
if data.Model != "llama" {
|
||||||
|
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||||
|
}
|
||||||
|
if data.ModelID != "llama3" {
|
||||||
|
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||||
|
parent := context.Background()
|
||||||
|
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
|
if v := parent.Value(ContextKey); v != nil {
|
||||||
|
t.Errorf("parent context was mutated: %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
wantReq string
|
||||||
|
wantReal string
|
||||||
|
wantBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model present, same name",
|
||||||
|
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
||||||
|
wantReq: "llama3",
|
||||||
|
wantReal: "llama3",
|
||||||
|
wantBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model present, aliased",
|
||||||
|
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
||||||
|
wantReq: "llama",
|
||||||
|
wantReal: "llama3",
|
||||||
|
wantBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model absent",
|
||||||
|
ctx: context.Background(),
|
||||||
|
wantReq: "",
|
||||||
|
wantReal: "",
|
||||||
|
wantBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model is empty string",
|
||||||
|
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
||||||
|
wantReq: "",
|
||||||
|
wantReal: "",
|
||||||
|
wantBool: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotData, ok := ReadContext(tt.ctx)
|
||||||
|
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
||||||
|
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,269 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// apiUnloadTimeout is the timeout the API endpoints pass to Unload when
// stopping processes.
const apiUnloadTimeout time.Duration = 10 * time.Second
|
||||||
|
|
||||||
|
// modelRecord is a single element of the "data" array returned by the
// OpenAI-compatible /v1/models listing. Name, Description and Meta are left
// out of the JSON when empty.
type modelRecord struct {
	// ID is the identifier clients use to request the model.
	ID string `json:"id"`
	// Object is the record type tag.
	Object string `json:"object"`
	// Created is a Unix timestamp in seconds.
	Created int64 `json:"created"`
	// OwnedBy names the owner reported for the model.
	OwnedBy string `json:"owned_by"`
	// Name is an optional display name.
	Name string `json:"name,omitempty"`
	// Description is an optional free-text description.
	Description string `json:"description,omitempty"`
	// Meta carries optional extra metadata.
	Meta map[string]any `json:"meta,omitempty"`
}
|
||||||
|
|
||||||
|
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||||
|
// (with optional aliases) plus peer models.
|
||||||
|
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||||
|
created := time.Now().Unix()
|
||||||
|
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||||
|
|
||||||
|
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
|
||||||
|
rec := modelRecord{
|
||||||
|
ID: id,
|
||||||
|
Object: "model",
|
||||||
|
Created: created,
|
||||||
|
OwnedBy: "llama-swap",
|
||||||
|
Name: strings.TrimSpace(name),
|
||||||
|
Description: strings.TrimSpace(description),
|
||||||
|
}
|
||||||
|
if len(metadata) > 0 {
|
||||||
|
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||||
|
}
|
||||||
|
return rec
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, mc := range s.cfg.Models {
|
||||||
|
if mc.Unlisted {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
|
||||||
|
|
||||||
|
if s.cfg.IncludeAliasesInList {
|
||||||
|
for _, alias := range mc.Aliases {
|
||||||
|
if alias := strings.TrimSpace(alias); alias != "" {
|
||||||
|
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for peerID, peer := range s.cfg.Peers {
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID })
|
||||||
|
|
||||||
|
// Echo the Origin so browser clients can read the listing.
|
||||||
|
if origin := r.Header.Get("Origin"); origin != "" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// runningModel is a single element of the "running" array returned by the
// /running listing. Every field is always present in the JSON.
type runningModel struct {
	// Model is the model ID.
	Model string `json:"model"`
	// State is the process state rendered as a string.
	State string `json:"state"`
	// Cmd is the configured command for the model.
	Cmd string `json:"cmd"`
	// Proxy is the configured proxy target for the model.
	Proxy string `json:"proxy"`
	// TTL is the configured unload-after value.
	TTL int `json:"ttl"`
	// Name is the configured display name.
	Name string `json:"name"`
	// Description is the configured description.
	Description string `json:"description"`
}
|
||||||
|
|
||||||
|
// handleUnload stops every running local process. Peer models are remote and
|
||||||
|
// unaffected.
|
||||||
|
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.local.Unload(apiUnloadTimeout)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRunning lists local processes that are not stopped, joining each model
|
||||||
|
// ID against its config for the cmd/proxy/ttl/name/description metadata.
|
||||||
|
func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) {
|
||||||
|
states := s.local.RunningModels()
|
||||||
|
list := make([]runningModel, 0, len(states))
|
||||||
|
for id, state := range states {
|
||||||
|
mc := s.cfg.Models[id]
|
||||||
|
list = append(list, runningModel{
|
||||||
|
Model: id,
|
||||||
|
State: string(state),
|
||||||
|
Cmd: mc.Cmd,
|
||||||
|
Proxy: mc.Proxy,
|
||||||
|
TTL: mc.UnloadAfter,
|
||||||
|
Name: mc.Name,
|
||||||
|
Description: mc.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model })
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{"running": list})
|
||||||
|
}
|
||||||
|
|
||||||
|
// discardResponseWriter is an http.ResponseWriter that throws the response
// body away and remembers only the most recent status code. It is used for
// preload requests, where only success or failure matters.
type discardResponseWriter struct {
	header http.Header // created lazily by Header
	status int         // last value passed to WriteHeader
}

// Header returns the response header map, allocating it on first use.
func (d *discardResponseWriter) Header() http.Header {
	if d.header != nil {
		return d.header
	}
	d.header = http.Header{}
	return d.header
}

// Write reports p as fully written without storing it.
func (d *discardResponseWriter) Write(p []byte) (int, error) {
	written := len(p)
	return written, nil
}

// WriteHeader records status, overwriting any earlier value.
func (d *discardResponseWriter) WriteHeader(status int) {
	d.status = status
}
|
||||||
|
|
||||||
|
// startPreload fires a background GET / at every model named in
|
||||||
|
// Hooks.OnStartup.Preload so they are warm before the first real request.
|
||||||
|
// Preload names are already resolved to real model IDs by config loading.
|
||||||
|
func (s *Server) startPreload() {
|
||||||
|
models := s.cfg.Hooks.OnStartup.Preload
|
||||||
|
if len(models) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for _, modelID := range models {
|
||||||
|
if !s.local.Handles(modelID) {
|
||||||
|
s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.proxylog.Infof("preloading model: %s", modelID)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||||
|
|
||||||
|
dw := &discardResponseWriter{status: http.StatusOK}
|
||||||
|
s.local.ServeHTTP(dw, req)
|
||||||
|
|
||||||
|
success := dw.status < http.StatusBadRequest
|
||||||
|
if !success {
|
||||||
|
s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status)
|
||||||
|
}
|
||||||
|
event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMetrics serves Prometheus-format performance metrics. Returns 503 when
|
||||||
|
// performance monitoring is disabled.
|
||||||
|
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if s.perf == nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte("# performance monitor not available\n"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.perf.MetricsHandler().ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHealth is a liveness probe: it always replies 200 with the body "OK".
func handleHealth(w http.ResponseWriter, r *http.Request) {
	const reply = "OK"

	w.WriteHeader(http.StatusOK)
	w.Write([]byte(reply))
}
|
||||||
|
|
||||||
|
// handleRootRedirect sends the client to the web UI with a 302 redirect.
func handleRootRedirect(w http.ResponseWriter, r *http.Request) {
	const target = "/ui"
	http.Redirect(w, r, target, http.StatusFound)
}
|
||||||
|
|
||||||
|
// handleUpstreamRedirect sends the client to the UI's model page with a 302
// redirect.
func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
	const target = "/ui/models"
	http.Redirect(w, r, target, http.StatusFound)
}
|
||||||
|
|
||||||
|
// handleUpstream proxies ANY request under /upstream/<model>/<path> directly to
|
||||||
|
// the model's process, bypassing model dispatch by body/query inspection.
|
||||||
|
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||||
|
upstreamPath := r.PathValue("upstreamPath")
|
||||||
|
|
||||||
|
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
||||||
|
if !found {
|
||||||
|
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect /upstream/model to /upstream/model/ so relative URLs in upstream
|
||||||
|
// responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method.
|
||||||
|
if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") {
|
||||||
|
newPath := "/upstream/" + searchName + "/"
|
||||||
|
if r.URL.RawQuery != "" {
|
||||||
|
newPath += "?" + r.URL.RawQuery
|
||||||
|
}
|
||||||
|
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
||||||
|
http.Redirect(w, r, newPath, http.StatusMovedPermanently)
|
||||||
|
} else {
|
||||||
|
http.Redirect(w, r, newPath, http.StatusPermanentRedirect)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the /upstream/<model> prefix before forwarding.
|
||||||
|
r.URL.Path = remainingPath
|
||||||
|
// Pin the resolved model so the router skips body/query extraction.
|
||||||
|
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case s.local.Handles(modelID):
|
||||||
|
s.local.ServeHTTP(w, r)
|
||||||
|
case s.peer.Handles(modelID):
|
||||||
|
s.peer.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findModelInPath walks a slash-separated path, building up segments until one
|
||||||
|
// matches a configured model. This resolves model names that contain slashes
|
||||||
|
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||||
|
// remaining path, and whether a match was found.
|
||||||
|
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||||
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||||
|
name := ""
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = part
|
||||||
|
} else {
|
||||||
|
name = name + "/" + part
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelID, ok := cfg.RealModelName(name); ok {
|
||||||
|
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_HandleListModels(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"visible": {Name: "Visible", Description: "a model"},
|
||||||
|
"hidden": {Unlisted: true},
|
||||||
|
},
|
||||||
|
Peers: config.PeerDictionaryConfig{
|
||||||
|
"peer1": {Models: []string{"remote-model"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d", w.Code)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||||
|
t.Errorf("Access-Control-Allow-Origin = %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []modelRecord `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
ids := map[string]bool{}
|
||||||
|
for _, m := range resp.Data {
|
||||||
|
ids[m.ID] = true
|
||||||
|
}
|
||||||
|
if !ids["visible"] || !ids["remote-model"] {
|
||||||
|
t.Errorf("missing expected models: %v", ids)
|
||||||
|
}
|
||||||
|
if ids["hidden"] {
|
||||||
|
t.Error("unlisted model should not appear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleListModels_Aliases(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
IncludeAliasesInList: true,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"real": {Aliases: []string{"nick"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []modelRecord `json:"data"`
|
||||||
|
}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
ids := map[string]bool{}
|
||||||
|
for _, m := range resp.Data {
|
||||||
|
ids[m.ID] = true
|
||||||
|
}
|
||||||
|
if !ids["real"] || !ids["nick"] {
|
||||||
|
t.Errorf("expected alias entry; got %v", ids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_FindModelInPath(t *testing.T) {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||||
|
"author/model": {},
|
||||||
|
"simple": {},
|
||||||
|
}}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
path string
|
||||||
|
wantName string
|
||||||
|
wantRem string
|
||||||
|
wantFound bool
|
||||||
|
}{
|
||||||
|
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||||
|
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||||
|
{"/author/model", "author/model", "/", true},
|
||||||
|
{"/missing/v1", "", "", false},
|
||||||
|
{"/", "", "", false},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
name, _, rem, found := findModelInPath(cfg, c.path)
|
||||||
|
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||||
|
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||||
|
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
|
||||||
|
t.Run("proxies to local", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat", nil))
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Errorf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("redirects bare model path", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1", nil))
|
||||||
|
if w.Code != http.StatusMovedPermanently {
|
||||||
|
t.Errorf("status = %d, want 301", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown model 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/nope/v1", nil))
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/metrics", nil))
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want 503", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Redirects(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
for path, want := range map[string]string{"/": "/ui", "/upstream": "/ui/models"} {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||||
|
if w.Code != http.StatusFound {
|
||||||
|
t.Errorf("%s: status = %d, want 302", path, w.Code)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Location"); got != want {
|
||||||
|
t.Errorf("%s: Location = %q, want %q", path, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,270 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// apiModel describes one model in the modelStatus payload sent over
// /api/events. Local models carry their config and process state; peer models
// carry only Id and PeerID.
type apiModel struct {
	// Id is the model ID.
	Id string `json:"id"`
	// Name is the configured display name.
	Name string `json:"name"`
	// Description is the configured description.
	Description string `json:"description"`
	// State is the process state rendered as a string.
	State string `json:"state"`
	// Unlisted mirrors the model's Unlisted config flag.
	Unlisted bool `json:"unlisted"`
	// PeerID names the peer that serves the model; empty for local models.
	PeerID string `json:"peerID"`
	// Aliases are the configured alternative names; omitted from JSON when empty.
	Aliases []string `json:"aliases,omitempty"`
}
|
||||||
|
|
||||||
|
// modelStatus returns every configured model joined with its current process
|
||||||
|
// state (defaulting to "stopped"), followed by peer models.
|
||||||
|
func (s *Server) modelStatus() []apiModel {
|
||||||
|
running := s.local.RunningModels()
|
||||||
|
|
||||||
|
ids := make([]string, 0, len(s.cfg.Models))
|
||||||
|
for id := range s.cfg.Models {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
sort.Strings(ids)
|
||||||
|
|
||||||
|
models := make([]apiModel, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
mc := s.cfg.Models[id]
|
||||||
|
state := "stopped"
|
||||||
|
if st, ok := running[id]; ok {
|
||||||
|
state = string(st)
|
||||||
|
}
|
||||||
|
models = append(models, apiModel{
|
||||||
|
Id: id,
|
||||||
|
Name: mc.Name,
|
||||||
|
Description: mc.Description,
|
||||||
|
State: state,
|
||||||
|
Unlisted: mc.Unlisted,
|
||||||
|
Aliases: mc.Aliases,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for peerID, peer := range s.cfg.Peers {
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
models = append(models, apiModel{Id: modelID, PeerID: peerID})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPIUnloadAll stops every running local process.
|
||||||
|
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.local.Unload(apiUnloadTimeout)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPIUnloadModel stops a single named local process.
|
||||||
|
func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||||
|
realName, found := s.cfg.RealModelName(requested)
|
||||||
|
if !found {
|
||||||
|
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.local.Handles(realName) {
|
||||||
|
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.local.Unload(apiUnloadTimeout, realName)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPIMetrics serves the activity log as a JSON array.
|
||||||
|
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data, err := s.metrics.getMetricsJSON()
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPIPerformance serves the buffered system/GPU stats, optionally
|
||||||
|
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||||
|
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if s.perf == nil {
|
||||||
|
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sysStats, gpuStats := s.perf.Current()
|
||||||
|
|
||||||
|
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||||
|
after, err := time.Parse(time.RFC3339, afterStr)
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||||
|
for _, st := range sysStats {
|
||||||
|
if st.Timestamp.After(after) {
|
||||||
|
filteredSys = append(filteredSys, st)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sysStats = filteredSys
|
||||||
|
|
||||||
|
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
|
||||||
|
for _, g := range gpuStats {
|
||||||
|
if g.Timestamp.After(after) {
|
||||||
|
filteredGpu = append(filteredGpu, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gpuStats = filteredGpu
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"sys_stats": sysStats,
|
||||||
|
"gpu_stats": gpuStats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPIVersion serves the build metadata.
|
||||||
|
func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"version": s.build.Version,
|
||||||
|
"commit": s.build.Commit,
|
||||||
|
"build_date": s.build.Date,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAPICapture returns the stored request/response capture for a metric ID.
|
||||||
|
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id, err := strconv.Atoi(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
capture := s.metrics.getCaptureByID(id)
|
||||||
|
if capture == nil {
|
||||||
|
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(capture)
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write(jsonBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// messageType names the kind of payload carried by a messageEnvelope.
type messageType string

// Message kinds emitted on the /api/events stream.
const (
	msgTypeModelStatus messageType = "modelStatus"
	msgTypeLogData     messageType = "logData"
	msgTypeMetrics     messageType = "metrics"
	msgTypeInFlight    messageType = "inflight"
)

// messageEnvelope is the JSON wrapper for a single streamed event. Data is a
// string; handleAPIEvents fills it with the JSON encoding of the payload, so
// consumers decode it a second time according to Type.
type messageEnvelope struct {
	Type messageType `json:"type"`
	Data string      `json:"data"`
}
|
||||||
|
|
||||||
|
// handleAPIEvents streams server events (model status, log data, metrics,
|
||||||
|
// in-flight counts) to the client as Server-Sent Events.
|
||||||
|
func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
// prevent nginx from buffering SSE
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// internal/event already has a 50K event buffer
|
||||||
|
// a 1K message buffer should be enough, watch the logs for the warning that the sendBuffer is full
|
||||||
|
sendBuffer := make(chan messageEnvelope, 1024)
|
||||||
|
ctx, cancel := context.WithCancel(r.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
send := func(msg messageEnvelope) {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- msg:
|
||||||
|
case <-ctx.Done():
|
||||||
|
s.proxylog.Warn("handleAPIEvents send suppressed due to context done")
|
||||||
|
default:
|
||||||
|
s.proxylog.Warn("handleAPIEvents sendBuffer full, dropped message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sendModels := func() {
|
||||||
|
if data, err := json.Marshal(s.modelStatus()); err == nil {
|
||||||
|
send(messageEnvelope{Type: msgTypeModelStatus, Data: string(data)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sendLogData := func(source string, data []byte) {
|
||||||
|
if j, err := json.Marshal(map[string]string{"source": source, "data": string(data)}); err == nil {
|
||||||
|
send(messageEnvelope{Type: msgTypeLogData, Data: string(j)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||||
|
if j, err := json.Marshal(metrics); err == nil {
|
||||||
|
send(messageEnvelope{Type: msgTypeMetrics, Data: string(j)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sendInFlight := func(total int) {
|
||||||
|
if j, err := json.Marshal(map[string]int{"total": total}); err == nil {
|
||||||
|
send(messageEnvelope{Type: msgTypeInFlight, Data: string(j)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer event.On(func(e shared.ProcessStateChangeEvent) { sendModels() })()
|
||||||
|
defer event.On(func(e shared.ConfigFileChangedEvent) { sendModels() })()
|
||||||
|
defer s.proxylog.OnLogData(func(data []byte) { sendLogData("proxy", data) })()
|
||||||
|
defer s.upstreamlog.OnLogData(func(data []byte) { sendLogData("upstream", data) })()
|
||||||
|
defer event.On(func(e ActivityLogEvent) { sendMetrics([]ActivityLogEntry{e.Metrics}) })()
|
||||||
|
defer event.On(func(e shared.InFlightRequestsEvent) { sendInFlight(e.Total) })()
|
||||||
|
|
||||||
|
// initial payload
|
||||||
|
sendLogData("proxy", s.proxylog.GetHistory())
|
||||||
|
sendLogData("upstream", s.upstreamlog.GetHistory())
|
||||||
|
sendModels()
|
||||||
|
sendMetrics(s.metrics.getMetrics())
|
||||||
|
sendInFlight(int(s.inflight.Current()))
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-s.shutdownCtx.Done():
|
||||||
|
return
|
||||||
|
case msg := <-sendBuffer:
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "event:message\ndata:%s\n\n", data)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,103 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_InflightMiddleware(t *testing.T) {
|
||||||
|
c := &inflightCounter{}
|
||||||
|
mw := CreateInflightMiddleware(c)
|
||||||
|
|
||||||
|
var duringRequest int64
|
||||||
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
duringRequest = c.Current()
|
||||||
|
}))
|
||||||
|
|
||||||
|
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if duringRequest != 1 {
|
||||||
|
t.Errorf("counter during request = %d, want 1", duringRequest)
|
||||||
|
}
|
||||||
|
if got := c.Current(); got != 0 {
|
||||||
|
t.Errorf("counter after request = %d, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_APIVersion(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.build = BuildInfo{Version: "1.2.3", Commit: "deadbeef", Date: "2026-05-19"}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/version", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d", w.Code)
|
||||||
|
}
|
||||||
|
var got map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if got["version"] != "1.2.3" || got["commit"] != "deadbeef" || got["build_date"] != "2026-05-19" {
|
||||||
|
t.Errorf("body = %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_APIMetrics_Empty(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/metrics", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d", w.Code)
|
||||||
|
}
|
||||||
|
if body := strings.TrimSpace(w.Body.String()); body != "[]" {
|
||||||
|
t.Errorf("body = %q, want []", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_APIPerformance_Unavailable(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/performance", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want 503", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_APIEvents_InitialPayload(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/events", nil).WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("handler did not return after context cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
for _, want := range []string{`"type":"modelStatus"`, `"type":"inflight"`, `"type":"logData"`} {
|
||||||
|
if !strings.Contains(body, want) {
|
||||||
|
t.Errorf("initial SSE payload missing %s; body=%q", want, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,135 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"strings"

	"github.com/mostlygeek/llama-swap/internal/chain"
	"github.com/mostlygeek/llama-swap/internal/config"
	"github.com/mostlygeek/llama-swap/internal/router"
)
|
||||||
|
|
||||||
|
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||||
|
// config declares any. It accepts the key via Authorization: Bearer,
|
||||||
|
// Authorization: Basic (password field), or x-api-key. On success the auth
|
||||||
|
// headers are stripped so they never leak to upstream. When no keys are
|
||||||
|
// configured the middleware is a pass-through.
|
||||||
|
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||||
|
keys := cfg.RequiredAPIKeys
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
provided := extractAPIKey(r)
|
||||||
|
|
||||||
|
valid := false
|
||||||
|
for _, key := range keys {
|
||||||
|
if provided == key {
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||||
|
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Del("Authorization")
|
||||||
|
r.Header.Del("x-api-key")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractAPIKey pulls a candidate API key from the request.
//
// The Authorization header is consulted first: for the Basic scheme the key is
// the password part of the base64-decoded "user:password" pair, for the Bearer
// scheme it is the token itself. The scheme name is matched case-insensitively.
// If the header is absent, uses another scheme, cannot be decoded, or yields
// an empty key, the value of the x-api-key header is returned instead (which
// is "" when that header is missing too).
func extractAPIKey(r *http.Request) string {
	if scheme, credentials, ok := strings.Cut(r.Header.Get("Authorization"), " "); ok {
		switch {
		case strings.EqualFold(scheme, "Basic"):
			if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
				// Only the first ':' separates user from password; the
				// password itself may contain further colons.
				if _, password, found := strings.Cut(string(decoded), ":"); found && password != "" {
					return password // password field is the API key
				}
			}
		case strings.EqualFold(scheme, "Bearer"):
			if credentials != "" {
				return credentials
			}
		}
	}
	return r.Header.Get("x-api-key")
}
|
||||||
|
|
||||||
|
// CreateCORSMiddleware returns middleware that answers OPTIONS preflight
|
||||||
|
// requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS
|
||||||
|
// requests pass through untouched.
|
||||||
|
func CreateCORSMiddleware() chain.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodOptions {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
|
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers))
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With")
|
||||||
|
}
|
||||||
|
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTokenChar reports whether r is an ASCII letter, an ASCII digit, or one of
// the punctuation characters !#$%&'*+-.^_`|~ — the characters accepted in an
// HTTP token.
func isTokenChar(r rune) bool {
	isLetter := ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z')
	isDigit := '0' <= r && r <= '9'
	if isLetter || isDigit {
		return true
	}
	return strings.ContainsRune("!#$%&'*+-.^_`|~", r)
}
|
||||||
|
|
||||||
|
// sanitizeAccessControlRequestHeaderValues drops any header names that contain
|
||||||
|
// characters outside the HTTP token grammar before echoing them back.
|
||||||
|
func sanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||||
|
parts := strings.Split(headerValues, ",")
|
||||||
|
valid := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
for _, p := range parts {
|
||||||
|
v := strings.TrimSpace(p)
|
||||||
|
if v == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
validPart := true
|
||||||
|
for _, c := range v {
|
||||||
|
if !isTokenChar(c) {
|
||||||
|
validPart = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if validPart {
|
||||||
|
valid = append(valid, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(valid, ", ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||||
|
basicHeader := func(user, pass string) string {
|
||||||
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||||
|
}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth string
|
||||||
|
xapi string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"none", "", "", ""},
|
||||||
|
{"bearer", "Bearer tok123", "", "tok123"},
|
||||||
|
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||||
|
{"x-api-key", "", "xkey", "xkey"},
|
||||||
|
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||||
|
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||||
|
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if c.auth != "" {
|
||||||
|
r.Header.Set("Authorization", c.auth)
|
||||||
|
}
|
||||||
|
if c.xapi != "" {
|
||||||
|
r.Header.Set("x-api-key", c.xapi)
|
||||||
|
}
|
||||||
|
if got := extractAPIKey(r); got != c.want {
|
||||||
|
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"Content-Type, Authorization", "Content-Type, Authorization"},
|
||||||
|
{" X-Custom , Accept ", "X-Custom, Accept"},
|
||||||
|
{"Valid, Bad Header", "Valid"},
|
||||||
|
{"Bad@Header", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want {
|
||||||
|
t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_IsTokenChar(t *testing.T) {
|
||||||
|
for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" {
|
||||||
|
if !isTokenChar(r) {
|
||||||
|
t.Errorf("isTokenChar(%q) = false, want true", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, r := range " @()/\t\"" {
|
||||||
|
if isTokenChar(r) {
|
||||||
|
t.Errorf("isTokenChar(%q) = true, want false", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_AuthMiddleware(t *testing.T) {
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
||||||
|
t.Error("auth headers leaked to upstream")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no keys configured passes through", func(t *testing.T) {
|
||||||
|
mw := CreateAuthMiddleware(config.Config{})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
cfg := config.Config{RequiredAPIKeys: []string{"secret"}}
|
||||||
|
|
||||||
|
t.Run("valid key", func(t *testing.T) {
|
||||||
|
mw := CreateAuthMiddleware(cfg)
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Authorization", "Bearer secret")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid key", func(t *testing.T) {
|
||||||
|
mw := CreateAuthMiddleware(cfg)
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Authorization", "Bearer wrong")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", w.Code)
|
||||||
|
}
|
||||||
|
if w.Header().Get("WWW-Authenticate") == "" {
|
||||||
|
t.Error("missing WWW-Authenticate header")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReqRespCapture is a stored request/response pair for a single metered request.
//
// Headers are kept as single-value maps and bodies as raw bytes. ID is the key
// under which the capture is stored in and fetched from the capture cache.
type ReqRespCapture struct {
	ID          int               `json:"id"`
	ReqPath     string            `json:"req_path"`
	ReqHeaders  map[string]string `json:"req_headers"`
	ReqBody     []byte            `json:"req_body"`
	RespHeaders map[string]string `json:"resp_headers"`
	RespBody    []byte            `json:"resp_body"`
}
|
||||||
|
|
||||||
|
// captureFields is a bitmask controlling what a route stores in a ReqRespCapture.
type captureFields uint

// Individual capture bits, one per part of the request/response pair.
const (
	captureReqHeaders captureFields = 1 << iota
	captureReqBody
	captureRespHeaders
	captureRespBody
)

// Convenience combinations of the individual capture bits.
const (
	captureReqAll  = captureReqHeaders | captureReqBody   // request headers and body
	captureRespAll = captureRespHeaders | captureRespBody // response headers and body
	captureAll     = captureReqAll | captureRespAll       // every part
)
|
||||||
|
|
||||||
|
// captureFieldsByPath overrides the default capture mask for routes carrying
|
||||||
|
// large binary payloads (audio/image) where storing the full body is wasteful.
|
||||||
|
var captureFieldsByPath = map[string]captureFields{
|
||||||
|
"/v1/audio/speech": captureReqAll | captureRespHeaders,
|
||||||
|
"/v1/audio/voices": captureReqHeaders | captureRespAll,
|
||||||
|
"/v1/audio/transcriptions": captureReqHeaders | captureRespHeaders | captureRespBody,
|
||||||
|
"/v1/images/generations": captureReqAll | captureRespHeaders,
|
||||||
|
"/v1/images/edits": captureReqHeaders | captureRespHeaders,
|
||||||
|
"/sdapi/v1/txt2img": captureReqAll | captureRespHeaders,
|
||||||
|
"/sdapi/v1/img2img": captureReqHeaders | captureRespHeaders,
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureFieldsFor returns the capture mask for a request path. Unlisted routes
|
||||||
|
// (the OpenAI-compatible JSON endpoints) capture everything.
|
||||||
|
func captureFieldsFor(path string) captureFields {
|
||||||
|
if cf, ok := captureFieldsByPath[path]; ok {
|
||||||
|
return cf
|
||||||
|
}
|
||||||
|
return captureAll
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstdEncOptions are the zstd encoder options shared by every pooled encoder.
// They select the SpeedBetterCompression level, which favours compression
// ratio over speed.
// NOTE(review): the library also offers zstd.SpeedBestCompression, so this is
// a high level rather than the maximum one.
var zstdEncOptions = []zstd.EOption{
	zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
|
||||||
|
|
||||||
|
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||||
|
var zstdEncPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||||
|
return enc
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||||
|
var zstdDecPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
dec, _ := zstd.NewReader(nil)
|
||||||
|
return dec
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||||
|
// Returns the compressed bytes and the original CBOR byte count for logging.
|
||||||
|
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||||
|
cborBytes, err := cbor.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||||
|
}
|
||||||
|
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||||
|
defer zstdEncPool.Put(zenc)
|
||||||
|
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressCapture decompresses zstd-compressed CBOR into a ReqRespCapture.
|
||||||
|
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||||
|
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||||
|
defer zstdDecPool.Put(dec)
|
||||||
|
cborBytes, err := dec.DecodeAll(data, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||||
|
}
|
||||||
|
var capture ReqRespCapture
|
||||||
|
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||||
|
}
|
||||||
|
return &capture, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCapture compresses and stores a capture in the cache. Returns true if the
|
||||||
|
// capture was stored.
|
||||||
|
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||||
|
if !mp.enableCaptures {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||||
|
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||||
|
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCaptureByID decompresses and unmarshals a capture by ID. Returns nil if
|
||||||
|
// the capture is not found or decompression fails.
|
||||||
|
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||||
|
if mp.captureCache == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := mp.captureCache.Get(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
capture, err := decompressCapture(data)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return capture
|
||||||
|
}
|
||||||
|
|
||||||
|
// sensitiveHeaders is the set of header names whose values are redacted in
// captures. Keys are lower-case because redactHeaders lower-cases a header
// name before looking it up here.
var sensitiveHeaders = map[string]bool{
	// credentials
	"authorization":       true,
	"proxy-authorization": true,
	"x-api-key":           true,

	// session state
	"cookie":     true,
	"set-cookie": true,
}
|
||||||
|
|
||||||
|
// headerMap converts an http.Header into a map holding one value per header:
// the first value of each key. Keys whose value list is empty are left out;
// any additional values of a multi-value header are not carried over.
func headerMap(h http.Header) map[string]string {
	flat := make(map[string]string, len(h))
	for name := range h {
		if vals := h[name]; len(vals) != 0 {
			flat[name] = vals[0]
		}
	}
	return flat
}
|
||||||
|
|
||||||
|
// redactHeaders replaces sensitive header values in-place with "[REDACTED]".
|
||||||
|
func redactHeaders(headers map[string]string) {
|
||||||
|
for key := range headers {
|
||||||
|
if sensitiveHeaders[strings.ToLower(key)] {
|
||||||
|
headers[key] = "[REDACTED]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_CaptureCompressRoundtrip(t *testing.T) {
|
||||||
|
orig := &ReqRespCapture{
|
||||||
|
ID: 7,
|
||||||
|
ReqPath: "/v1/chat/completions",
|
||||||
|
ReqHeaders: map[string]string{"Content-Type": "application/json"},
|
||||||
|
ReqBody: []byte(`{"model":"m"}`),
|
||||||
|
RespHeaders: map[string]string{"Content-Type": "application/json"},
|
||||||
|
RespBody: []byte(`{"usage":{}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, uncompressed, err := compressCapture(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("compressCapture: %v", err)
|
||||||
|
}
|
||||||
|
if uncompressed == 0 || len(compressed) == 0 {
|
||||||
|
t.Fatalf("unexpected sizes: uncompressed=%d compressed=%d", uncompressed, len(compressed))
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := decompressCapture(compressed)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decompressCapture: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != orig.ID || got.ReqPath != orig.ReqPath ||
|
||||||
|
!bytes.Equal(got.ReqBody, orig.ReqBody) || !bytes.Equal(got.RespBody, orig.RespBody) {
|
||||||
|
t.Fatalf("roundtrip mismatch: %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_CaptureStoreAndRetrieve(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||||
|
if !mm.enableCaptures {
|
||||||
|
t.Fatal("captures should be enabled with non-zero buffer")
|
||||||
|
}
|
||||||
|
|
||||||
|
capture := ReqRespCapture{ID: 3, ReqPath: "/v1/chat/completions", ReqBody: []byte("hello")}
|
||||||
|
if !mm.addCapture(capture) {
|
||||||
|
t.Fatal("addCapture returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := mm.getCaptureByID(3)
|
||||||
|
if got == nil || !bytes.Equal(got.ReqBody, []byte("hello")) {
|
||||||
|
t.Fatalf("getCaptureByID = %+v", got)
|
||||||
|
}
|
||||||
|
if mm.getCaptureByID(999) != nil {
|
||||||
|
t.Fatal("expected nil for unknown capture ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_CaptureDisabled(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 0)
|
||||||
|
if mm.enableCaptures {
|
||||||
|
t.Fatal("captures should be disabled with zero buffer")
|
||||||
|
}
|
||||||
|
if mm.addCapture(ReqRespCapture{ID: 1}) {
|
||||||
|
t.Fatal("addCapture should return false when disabled")
|
||||||
|
}
|
||||||
|
if mm.getCaptureByID(1) != nil {
|
||||||
|
t.Fatal("getCaptureByID should return nil when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_CaptureFieldsFor(t *testing.T) {
|
||||||
|
if got := captureFieldsFor("/v1/chat/completions"); got != captureAll {
|
||||||
|
t.Fatalf("default = %b, want captureAll", got)
|
||||||
|
}
|
||||||
|
if got := captureFieldsFor("/v1/audio/speech"); got != captureReqAll|captureRespHeaders {
|
||||||
|
t.Fatalf("/v1/audio/speech = %b", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset (CreateConcurrencyMiddleware
// applies it whenever the configured value is not positive). Matches the
// legacy proxy.Process default.
const defaultConcurrencyLimit = 10
|
||||||
|
|
||||||
|
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
|
||||||
|
// model-dispatched requests per model. Each model gets a semaphore sized to
|
||||||
|
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
|
||||||
|
// immediately acquire a slot is rejected with 429. Models without a local
|
||||||
|
// config entry (e.g. peer-routed models) are not limited.
|
||||||
|
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
||||||
|
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
|
||||||
|
for id, mc := range cfg.Models {
|
||||||
|
limit := defaultConcurrencyLimit
|
||||||
|
if mc.ConcurrencyLimit > 0 {
|
||||||
|
limit = mc.ConcurrencyLimit
|
||||||
|
}
|
||||||
|
semaphores[id] = semaphore.NewWeighted(int64(limit))
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data, err := router.FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
router.SendError(w, r, router.ErrNoModelInContext)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// fall through for peer models
|
||||||
|
sem, ok := semaphores[data.ModelID]
|
||||||
|
if !ok {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !sem.TryAcquire(1) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
w.Write([]byte(`{"error":"Too many requests"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sem.Release(1)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
func concurrencyTestReq(model string) *http.Request {
|
||||||
|
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
||||||
|
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"m1": {ConcurrencyLimit: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
entered := make(chan struct{})
|
||||||
|
release := make(chan struct{})
|
||||||
|
var once sync.Once
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
once.Do(func() { close(entered) })
|
||||||
|
<-release
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||||
|
|
||||||
|
// First request occupies the only slot.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
|
||||||
|
}()
|
||||||
|
<-entered
|
||||||
|
|
||||||
|
// Second concurrent request is rejected with 429.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("over-limit status = %d, want 429", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Once the slot frees, a new request succeeds.
|
||||||
|
close(release)
|
||||||
|
<-done
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("post-release status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{}}
|
||||||
|
|
||||||
|
called := 0
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called++
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
|
||||||
|
if w.Code != http.StatusOK || called != 1 {
|
||||||
|
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,256 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_DecompressBody(t *testing.T) {
|
||||||
|
plain := []byte("hello world")
|
||||||
|
|
||||||
|
var gz bytes.Buffer
|
||||||
|
gw := gzip.NewWriter(&gz)
|
||||||
|
gw.Write(plain)
|
||||||
|
gw.Close()
|
||||||
|
|
||||||
|
var fl bytes.Buffer
|
||||||
|
fw, _ := flate.NewWriter(&fl, flate.DefaultCompression)
|
||||||
|
fw.Write(plain)
|
||||||
|
fw.Close()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
encoding string
|
||||||
|
}{
|
||||||
|
{"plain", plain, ""},
|
||||||
|
{"gzip", gz.Bytes(), "gzip"},
|
||||||
|
{"deflate", fl.Bytes(), "deflate"},
|
||||||
|
{"unknown passthrough", plain, "br"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
got, err := decompressBody(c.body, c.encoding)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decompressBody: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, plain) {
|
||||||
|
t.Errorf("got %q, want %q", got, plain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_FilterAcceptEncoding(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"", ""},
|
||||||
|
{"gzip, deflate, br", "gzip, deflate"},
|
||||||
|
{"br, zstd", ""},
|
||||||
|
{"gzip;q=1.0", "gzip;q=1.0"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
if got := filterAcceptEncoding(c.in); got != c.want {
|
||||||
|
t.Errorf("filterAcceptEncoding(%q) = %q, want %q", c.in, got, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_BodyCopier_Flush(t *testing.T) {
|
||||||
|
bc := newBodyCopier(httptest.NewRecorder())
|
||||||
|
bc.Write([]byte("data"))
|
||||||
|
bc.Flush()
|
||||||
|
if bc.Status() != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", bc.Status())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hijackRecorder is a test double: an httptest.ResponseRecorder that also
// satisfies http.Hijacker by handing out a preconfigured connection, so code
// that forwards Hijack calls can be exercised.
type hijackRecorder struct {
	*httptest.ResponseRecorder
	conn net.Conn
}

// Hijack returns the stored connection together with a fresh buffered
// reader/writer pair layered over it. It never fails.
func (h *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	reader := bufio.NewReader(h.conn)
	writer := bufio.NewWriter(h.conn)
	return h.conn, bufio.NewReadWriter(reader, writer), nil
}
|
||||||
|
|
||||||
|
func TestServer_BodyCopier_Hijack(t *testing.T) {
|
||||||
|
t.Run("forwards to underlying hijacker", func(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
bc := newBodyCopier(&hijackRecorder{httptest.NewRecorder(), server})
|
||||||
|
conn, _, err := bc.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Hijack: %v", err)
|
||||||
|
}
|
||||||
|
if conn != server {
|
||||||
|
t.Errorf("Hijack returned unexpected conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("errors when underlying writer is not a hijacker", func(t *testing.T) {
|
||||||
|
bc := newBodyCopier(httptest.NewRecorder())
|
||||||
|
if _, _, err := bc.Hijack(); err == nil {
|
||||||
|
t.Error("expected error hijacking a non-Hijacker ResponseWriter")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_BodyCopier_SkipsBufferingOnUpgrade(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
bc := newBodyCopier(rec)
|
||||||
|
bc.WriteHeader(http.StatusSwitchingProtocols)
|
||||||
|
bc.Write([]byte("websocket frame bytes"))
|
||||||
|
|
||||||
|
if bc.body.Len() != 0 {
|
||||||
|
t.Errorf("upgrade body buffered = %q, want empty", bc.body.Bytes())
|
||||||
|
}
|
||||||
|
if got := rec.Body.String(); got != "websocket frame bytes" {
|
||||||
|
t.Errorf("client body = %q, want %q", got, "websocket frame bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HeaderMapAndRedact(t *testing.T) {
|
||||||
|
h := http.Header{
|
||||||
|
"Content-Type": {"application/json"},
|
||||||
|
"Authorization": {"Bearer secret"},
|
||||||
|
"X-Api-Key": {"key123"},
|
||||||
|
}
|
||||||
|
m := headerMap(h)
|
||||||
|
if m["Content-Type"] != "application/json" {
|
||||||
|
t.Errorf("Content-Type = %q", m["Content-Type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
redactHeaders(m)
|
||||||
|
if m["Authorization"] != "[REDACTED]" || m["X-Api-Key"] != "[REDACTED]" {
|
||||||
|
t.Errorf("sensitive headers not redacted: %v", m)
|
||||||
|
}
|
||||||
|
if m["Content-Type"] != "application/json" {
|
||||||
|
t.Error("non-sensitive header should not be redacted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_StripVersionPrefix(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/v/v1/chat", nil)
|
||||||
|
stripVersionPrefix(r)
|
||||||
|
if r.URL.Path != "/v1/chat" {
|
||||||
|
t.Errorf("path = %q, want /v1/chat", r.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := httptest.NewRequest(http.MethodGet, "/v1/chat", nil)
|
||||||
|
stripVersionPrefix(r2)
|
||||||
|
if r2.URL.Path != "/v1/chat" {
|
||||||
|
t.Errorf("path = %q, want unchanged", r2.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_CloseStreams(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.CloseStreams()
|
||||||
|
select {
|
||||||
|
case <-s.shutdownCtx.Done():
|
||||||
|
default:
|
||||||
|
t.Error("CloseStreams did not cancel shutdown context")
|
||||||
|
}
|
||||||
|
s.CloseStreams() // idempotent
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUIAndFavicon(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
for _, path := range []string{"/ui/", "/favicon.ico"} {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||||
|
// The embedded ui_dist only carries placeholder.txt in test builds, so
|
||||||
|
// these resolve to 404 — the handlers still execute end to end.
|
||||||
|
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("%s: status = %d", path, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleAPIUnloadAll(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d", w.Code)
|
||||||
|
}
|
||||||
|
if local.unloadCalls.Load() != 1 {
|
||||||
|
t.Errorf("unloadCalls = %d, want 1", local.unloadCalls.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleAPIUnloadModel(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
|
||||||
|
t.Run("known model", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/m1", nil))
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown model 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/nope", nil))
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleAPICapture(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.metrics = newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||||
|
s.metrics.addCapture(ReqRespCapture{ID: 42, ReqPath: "/v1/chat/completions"})
|
||||||
|
|
||||||
|
t.Run("found", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/42", nil))
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d", w.Code)
|
||||||
|
}
|
||||||
|
if !bytes.Contains(w.Body.Bytes(), []byte("/v1/chat/completions")) {
|
||||||
|
t.Errorf("body = %q", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not found", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/999", nil))
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid id", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/abc", nil))
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("status = %d, want 400", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,218 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateFilterMiddleware returns middleware that applies per-model request-body
|
||||||
|
// filters to JSON requests before they are forwarded upstream:
|
||||||
|
//
|
||||||
|
// - UseModelName rewrite (issue #69)
|
||||||
|
// - StripParams removal (issue #174)
|
||||||
|
// - SetParams injection (issue #453)
|
||||||
|
// - SetParamsByID per-alias overrides
|
||||||
|
//
|
||||||
|
// Non-JSON requests (GET, multipart forms) pass through untouched. The buffered
|
||||||
|
// body is re-attached with Content-Length / Transfer-Encoding cleanup so the
|
||||||
|
// downstream reverse proxy forwards the correct bytes (see issue #11).
|
||||||
|
func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.Contains(r.Header.Get("Content-Type"), "application/json") {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := router.FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
router.SendError(w, r, router.ErrNoModelInContext)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
useModelName, filters, ok := resolveFilters(cfg, data.Model)
|
||||||
|
if !ok {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
r.Header.Del("Transfer-Encoding")
|
||||||
|
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||||
|
r.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFormFilterMiddleware returns middleware that applies the UseModelName
|
||||||
|
// rewrite (issue #69) to multipart/form-data requests before they are forwarded
|
||||||
|
// upstream. JSON-body filters (StripParams, SetParams) do not apply to form
|
||||||
|
// endpoints; only the "model" field is rewritten.
|
||||||
|
//
|
||||||
|
// Non-multipart requests pass through untouched. When a rewrite is needed the
|
||||||
|
// form is reconstructed and re-attached with Content-Type / Content-Length
|
||||||
|
// cleanup so the downstream reverse proxy forwards the correct bytes.
|
||||||
|
func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := router.FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
router.SendError(w, r, router.ErrNoModelInContext)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
useModelName, _, ok := resolveFilters(cfg, data.Model)
|
||||||
|
if !ok || useModelName == "" {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
r.MultipartForm = nil
|
||||||
|
r.Header.Del("Transfer-Encoding")
|
||||||
|
r.Header.Set("Content-Type", contentType)
|
||||||
|
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||||
|
r.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteMultipartModel reconstructs a multipart form, replacing every "model"
// field value with useModelName. All other value fields and all uploaded files
// are copied across unchanged. It returns the encoded body and the matching
// Content-Type header (which carries the newly generated boundary).
//
// File parts keep the MIME headers they were uploaded with (for example
// "Content-Type: audio/wav") rather than being re-labelled as
// application/octet-stream. Parts are emitted in map iteration order, so the
// order of fields in the output is not guaranteed to match the input.
func rewriteMultipartModel(form *multipart.Form, useModelName string) ([]byte, string, error) {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)

	for key, values := range form.Value {
		for _, value := range values {
			if key == "model" {
				value = useModelName
			}
			field, err := mw.CreateFormField(key)
			if err != nil {
				return nil, "", fmt.Errorf("error recreating form field %s: %w", key, err)
			}
			if _, err := field.Write([]byte(value)); err != nil {
				return nil, "", fmt.Errorf("error writing form field %s: %w", key, err)
			}
		}
	}

	for key, headers := range form.File {
		for _, fh := range headers {
			if err := copyMultipartFilePart(mw, key, fh); err != nil {
				return nil, "", err
			}
		}
	}

	if err := mw.Close(); err != nil {
		return nil, "", fmt.Errorf("error finalizing multipart form: %w", err)
	}
	return buf.Bytes(), mw.FormDataContentType(), nil
}

// copyMultipartFilePart appends one uploaded file to mw under the form field
// key, streaming its content from fh.
func copyMultipartFilePart(mw *multipart.Writer, key string, fh *multipart.FileHeader) error {
	var (
		part io.Writer
		err  error
	)
	if fh.Header.Get("Content-Disposition") != "" {
		// Reuse the part's own headers so its Content-Type and original
		// Content-Disposition (field name and filename) survive the rewrite.
		part, err = mw.CreatePart(fh.Header)
	} else {
		// No usable headers (e.g. a hand-built FileHeader): fall back to a
		// generic file part.
		part, err = mw.CreateFormFile(key, fh.Filename)
	}
	if err != nil {
		return fmt.Errorf("error recreating form file %s: %w", key, err)
	}

	file, err := fh.Open()
	if err != nil {
		return fmt.Errorf("error opening uploaded file %s: %w", key, err)
	}
	defer file.Close()

	if _, err := io.Copy(part, file); err != nil {
		return fmt.Errorf("error copying file data %s: %w", key, err)
	}
	return nil
}
|
||||||
|
|
||||||
|
// resolveFilters returns the filter settings for a requested model. UseModelName
|
||||||
|
// only applies to local models; peers carry filters but no name rewrite.
|
||||||
|
func resolveFilters(cfg config.Config, requested string) (useModelName string, filters config.Filters, ok bool) {
|
||||||
|
if realName, found := cfg.RealModelName(requested); found {
|
||||||
|
mc := cfg.Models[realName]
|
||||||
|
return mc.UseModelName, mc.Filters.Filters, true
|
||||||
|
}
|
||||||
|
for _, peer := range cfg.Peers {
|
||||||
|
for _, m := range peer.Models {
|
||||||
|
if m == requested {
|
||||||
|
return "", peer.Filters, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", config.Filters{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyFilters rewrites the JSON body in place. Order matches the legacy
|
||||||
|
// ProxyManager: useModelName, stripParams, setParams, then setParamsByID (which
|
||||||
|
// can override setParams).
|
||||||
|
func applyFilters(body []byte, requested, useModelName string, f config.Filters) ([]byte, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if useModelName != "" {
|
||||||
|
if body, err = sjson.SetBytes(body, "model", useModelName); err != nil {
|
||||||
|
return nil, fmt.Errorf("error rewriting model name in JSON: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, param := range f.SanitizedStripParams() {
|
||||||
|
if body, err = sjson.DeleteBytes(body, param); err != nil {
|
||||||
|
return nil, fmt.Errorf("error stripping parameter %s from request", param)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setParams, setKeys := f.SanitizedSetParams()
|
||||||
|
for _, key := range setKeys {
|
||||||
|
if body, err = sjson.SetBytes(body, key, setParams[key]); err != nil {
|
||||||
|
return nil, fmt.Errorf("error setting parameter %s in request", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
byID, byIDKeys := f.SanitizedSetParamsByID(requested)
|
||||||
|
for _, key := range byIDKeys {
|
||||||
|
if body, err = sjson.SetBytes(body, key, byID[key]); err != nil {
|
||||||
|
return nil, fmt.Errorf("error setting parameter %s in request", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_ApplyFilters(t *testing.T) {
|
||||||
|
t.Run("useModelName rewrite", func(t *testing.T) {
|
||||||
|
out, err := applyFilters([]byte(`{"model":"alias","temp":1}`), "alias", "real-model", config.Filters{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("applyFilters: %v", err)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "model").String(); got != "real-model" {
|
||||||
|
t.Errorf("model = %q, want real-model", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("strip and set params", func(t *testing.T) {
|
||||||
|
f := config.Filters{
|
||||||
|
StripParams: "temperature",
|
||||||
|
SetParams: map[string]any{"top_p": 0.9},
|
||||||
|
}
|
||||||
|
out, err := applyFilters([]byte(`{"model":"m","temperature":0.7}`), "m", "", f)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("applyFilters: %v", err)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "temperature").Exists() {
|
||||||
|
t.Error("temperature should be stripped")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.9 {
|
||||||
|
t.Errorf("top_p = %v, want 0.9", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("setParamsByID overrides setParams", func(t *testing.T) {
|
||||||
|
f := config.Filters{
|
||||||
|
SetParams: map[string]any{"top_p": 0.5},
|
||||||
|
SetParamsByID: map[string]map[string]any{"alias": {"top_p": 0.1}},
|
||||||
|
}
|
||||||
|
out, err := applyFilters([]byte(`{"model":"alias"}`), "alias", "", f)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("applyFilters: %v", err)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.1 {
|
||||||
|
t.Errorf("top_p = %v, want 0.1", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_RewriteMultipartModel(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
mw := multipart.NewWriter(&buf)
|
||||||
|
mw.WriteField("model", "old-name")
|
||||||
|
mw.WriteField("language", "en")
|
||||||
|
fw, _ := mw.CreateFormFile("file", "audio.wav")
|
||||||
|
fw.Write([]byte("RIFFdata"))
|
||||||
|
mw.Close()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||||
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
t.Fatalf("ParseMultipartForm: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, contentType, err := rewriteMultipartModel(r.MultipartForm, "new-name")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rewriteMultipartModel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := multipart.NewReader(bytes.NewReader(body), boundaryOf(t, contentType)).ReadForm(32 << 20)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("re-parse: %v", err)
|
||||||
|
}
|
||||||
|
if got := parsed.Value["model"][0]; got != "new-name" {
|
||||||
|
t.Errorf("model = %q, want new-name", got)
|
||||||
|
}
|
||||||
|
if got := parsed.Value["language"][0]; got != "en" {
|
||||||
|
t.Errorf("language = %q, want en", got)
|
||||||
|
}
|
||||||
|
fh := parsed.File["file"][0]
|
||||||
|
f, _ := fh.Open()
|
||||||
|
data, _ := io.ReadAll(f)
|
||||||
|
f.Close()
|
||||||
|
if string(data) != "RIFFdata" {
|
||||||
|
t.Errorf("file data = %q, want RIFFdata", data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// boundaryOf returns everything that follows the first "boundary=" in a
// Content-Type value, failing the test when the marker is absent. The text is
// returned as is; no unquoting or trimming is performed.
func boundaryOf(t *testing.T, contentType string) string {
	t.Helper()

	const marker = "boundary="
	at := strings.Index(contentType, marker)
	if at < 0 {
		t.Fatalf("no boundary in %q", contentType)
	}
	return contentType[at+len(marker):]
}
|
||||||
|
|
||||||
|
func TestServer_FormFilterMiddleware(t *testing.T) {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||||
|
"whisper": {UseModelName: "whisper-large-v3"},
|
||||||
|
}}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
mw := multipart.NewWriter(&buf)
|
||||||
|
mw.WriteField("model", "whisper")
|
||||||
|
fw, _ := mw.CreateFormFile("file", "a.wav")
|
||||||
|
fw.Write([]byte("xx"))
|
||||||
|
mw.Close()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||||
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
|
||||||
|
var gotModel string
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = r.ParseMultipartForm(32 << 20)
|
||||||
|
gotModel = r.MultipartForm.Value["model"][0]
|
||||||
|
})
|
||||||
|
CreateFormFilterMiddleware(cfg)(final).ServeHTTP(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
if gotModel != "whisper-large-v3" {
|
||||||
|
t.Errorf("model rewritten to %q, want whisper-large-v3", gotModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// inflightCounter keeps an atomic count of model-dispatched requests that are
// currently in flight. The zero value is ready to use and starts at zero.
type inflightCounter struct {
	total atomic.Int64
}

// Increment adds one to the counter and returns the new value.
func (c *inflightCounter) Increment() int64 {
	return c.total.Add(1)
}

// Decrement subtracts one from the counter and returns the new value.
func (c *inflightCounter) Decrement() int64 {
	return c.total.Add(-1)
}

// Current returns the counter's present value without changing it.
func (c *inflightCounter) Current() int64 {
	return c.total.Load()
}
|
||||||
|
|
||||||
|
// CreateInflightMiddleware returns middleware that increments the counter on
|
||||||
|
// entry and decrements on exit, emitting an InFlightRequestsEvent for each.
|
||||||
|
func CreateInflightMiddleware(c *inflightCounter) chain.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Increment())})
|
||||||
|
defer func() {
|
||||||
|
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Decrement())})
|
||||||
|
}()
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,233 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||||
|
// wiring each one's output per the logToStdout config value. The proxy and
|
||||||
|
// upstream monitors write into muxlog (rather than os.Stdout directly) so
|
||||||
|
// muxlog accumulates a combined history for the /logs endpoints, while each
|
||||||
|
// monitor keeps its own per-source history and event subscribers.
|
||||||
|
//
|
||||||
|
// Behaviour matches the legacy ProxyManager:
|
||||||
|
//
|
||||||
|
// - none: everything discarded
|
||||||
|
// - both: proxy + upstream both routed to muxlog -> stdout
|
||||||
|
// - upstream: only upstream routed to muxlog -> stdout; proxy discarded
|
||||||
|
// - proxy: only proxy routed to muxlog -> stdout; upstream discarded
|
||||||
|
//
|
||||||
|
// An empty or unrecognised value behaves like "proxy".
|
||||||
|
func NewLoggers(logToStdout string) (muxlog, proxylog, upstreamlog *logmon.Monitor) {
|
||||||
|
switch logToStdout {
|
||||||
|
case config.LogToStdoutNone:
|
||||||
|
muxlog = logmon.NewWriter(io.Discard)
|
||||||
|
proxylog = logmon.NewWriter(io.Discard)
|
||||||
|
upstreamlog = logmon.NewWriter(io.Discard)
|
||||||
|
case config.LogToStdoutBoth:
|
||||||
|
muxlog = logmon.NewWriter(os.Stdout)
|
||||||
|
proxylog = logmon.NewWriter(muxlog)
|
||||||
|
upstreamlog = logmon.NewWriter(muxlog)
|
||||||
|
case config.LogToStdoutUpstream:
|
||||||
|
muxlog = logmon.NewWriter(os.Stdout)
|
||||||
|
proxylog = logmon.NewWriter(io.Discard)
|
||||||
|
upstreamlog = logmon.NewWriter(muxlog)
|
||||||
|
default:
|
||||||
|
// config.LogToStdoutProxy, and the fallback for an unset value.
|
||||||
|
muxlog = logmon.NewWriter(os.Stdout)
|
||||||
|
proxylog = logmon.NewWriter(muxlog)
|
||||||
|
upstreamlog = logmon.NewWriter(io.Discard)
|
||||||
|
}
|
||||||
|
return muxlog, proxylog, upstreamlog
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLogs serves the historical proxy/upstream log. HTML clients are
|
||||||
|
// redirected to the UI.
|
||||||
|
func (s *Server) handleLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.Header.Get("Accept"), "text/html") {
|
||||||
|
http.Redirect(w, r, "/ui/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.Write(s.muxlog.GetHistory())
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLogger resolves a log monitor by id. An empty id maps to the combined
|
||||||
|
// muxlog; "proxy" and "upstream" select the respective monitors.
|
||||||
|
func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
||||||
|
switch logMonitorID {
|
||||||
|
case "":
|
||||||
|
return s.muxlog, nil
|
||||||
|
case "proxy":
|
||||||
|
return s.proxylog, nil
|
||||||
|
case "upstream":
|
||||||
|
return s.upstreamlog, nil
|
||||||
|
default:
|
||||||
|
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||||
|
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||||
|
return log, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLogStream tails a log monitor: it writes the history then streams live
|
||||||
|
// log data until the client disconnects or the server shuts down.
|
||||||
|
func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
// prevent nginx from buffering streamed logs
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
logMonitorID := strings.TrimPrefix(r.PathValue("logMonitorID"), "/")
|
||||||
|
// Strip a query string if it leaked into the path segment.
|
||||||
|
if idx := strings.Index(logMonitorID, "?"); idx != -1 {
|
||||||
|
logMonitorID = logMonitorID[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err := s.getLogger(logMonitorID)
|
||||||
|
if err != nil {
|
||||||
|
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, skipHistory := r.URL.Query()["no-history"]
|
||||||
|
if !skipHistory {
|
||||||
|
if history := logger.GetHistory(); len(history) != 0 {
|
||||||
|
w.Write(history)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendChan := make(chan []byte, 10)
|
||||||
|
ctx, cancel := context.WithCancel(r.Context())
|
||||||
|
defer cancel()
|
||||||
|
cancelSub := logger.OnLogData(func(data []byte) {
|
||||||
|
select {
|
||||||
|
case sendChan <- data:
|
||||||
|
case <-ctx.Done():
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer cancelSub()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-s.shutdownCtx.Done():
|
||||||
|
return
|
||||||
|
case data := <-sendChan:
|
||||||
|
w.Write(data)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestLogPathSkips holds the URL path prefixes that the access-log
// middleware does not record. These endpoints are polled often enough that
// logging them would bury the entries that matter.
var requestLogPathSkips = []string{
	"/wol-health",
	"/api/performance",
	"/metrics",
}
|
||||||
|
|
||||||
|
// statusRecorder wraps an http.ResponseWriter to capture the response status
// code and the number of body bytes written, so the access log can report
// them. Flush is forwarded so streaming handlers (SSE) still work, Hijack is
// forwarded so httputil.ReverseProxy can upgrade websocket connections, and
// Unwrap exposes the wrapped writer to http.ResponseController.
//
// Callers are expected to initialise status to http.StatusOK; a zero status is
// also upgraded to 200 on the first body write.
type statusRecorder struct {
	http.ResponseWriter
	status      int  // final status code sent (or implied) for the response
	size        int  // body bytes successfully written so far
	wroteHeader bool // true once the final (non-1xx) status has been decided
}

// WriteHeader records the status code and forwards the call.
//
// net/http only honours the first final status: later WriteHeader calls are
// ignored by the server, so they must not change what gets logged.
// Informational 1xx responses (other than 101) may precede the final status
// and are therefore forwarded without being recorded.
func (sr *statusRecorder) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !sr.wroteHeader && !informational {
		sr.wroteHeader = true
		sr.status = code
	}
	// Always forward: the underlying writer decides how to treat (and report)
	// superfluous calls.
	sr.ResponseWriter.WriteHeader(code)
}

// Write forwards the body bytes and adds the number actually written to size.
// A write without a prior WriteHeader implies a 200 response, after which the
// recorded status can no longer change.
func (sr *statusRecorder) Write(b []byte) (int, error) {
	if !sr.wroteHeader {
		sr.wroteHeader = true
		if sr.status == 0 {
			sr.status = http.StatusOK
		}
	}
	n, err := sr.ResponseWriter.Write(b)
	sr.size += n
	return n, err
}

// Flush forwards to the underlying ResponseWriter when it supports flushing
// and is a no-op otherwise.
func (sr *statusRecorder) Flush() {
	if f, ok := sr.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Hijack forwards to the underlying ResponseWriter so httputil.ReverseProxy can
// take over the connection for websocket upgrades. It returns an error when
// the underlying writer is not an http.Hijacker.
func (sr *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hj, ok := sr.ResponseWriter.(http.Hijacker); ok {
		return hj.Hijack()
	}
	return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
}

// Unwrap returns the wrapped ResponseWriter so http.ResponseController can
// reach optional features (deadlines, full duplex) of the original writer.
func (sr *statusRecorder) Unwrap() http.ResponseWriter {
	return sr.ResponseWriter
}
|
||||||
|
|
||||||
|
// clientIP resolves the originating client address for logging, preferring
// proxy headers over the raw connection address.
//
// Resolution order:
//  1. the first non-empty entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr, or r.RemoteAddr verbatim when it is not
//     in host:port form.
//
// Empty or whitespace-only header entries are skipped rather than returned,
// so a malformed header such as "," falls through to the next source.
// NOTE(review): both headers are supplied by the client or an intermediary;
// the result should not be used for access-control decisions.
func clientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		for _, entry := range strings.Split(xff, ",") {
			if entry = strings.TrimSpace(entry); entry != "" {
				return entry
			}
		}
	}
	if xr := strings.TrimSpace(r.Header.Get("X-Real-IP")); xr != "" {
		return xr
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
||||||
|
|
||||||
|
// CreateRequestLogMiddleware returns middleware that records one access-log
|
||||||
|
// line per request to proxylog, in the legacy format:
|
||||||
|
//
|
||||||
|
// clientIP "METHOD PATH PROTO" status bodySize "UA" duration
|
||||||
|
//
|
||||||
|
// Frequently-polled health/metrics paths are skipped. The path is captured
|
||||||
|
// before next runs because /upstream rewrites the request URL in place.
|
||||||
|
func CreateRequestLogMiddleware(proxylog *logmon.Monitor) chain.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
for _, prefix := range requestLogPathSkips {
|
||||||
|
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ip, method, path, proto, ua := clientIP(r), r.Method, r.URL.Path, r.Proto, r.UserAgent()
|
||||||
|
|
||||||
|
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
next.ServeHTTP(rec, r)
|
||||||
|
|
||||||
|
proxylog.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||||
|
ip, method, path, proto, rec.status, rec.size, ua, time.Since(start))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_NewLoggers(t *testing.T) {
|
||||||
|
t.Run("proxy mode routes proxy into muxlog, discards upstream", func(t *testing.T) {
|
||||||
|
mux, proxy, upstream := NewLoggers(config.LogToStdoutProxy)
|
||||||
|
proxy.Info("PROXYLINE")
|
||||||
|
upstream.Info("UPSTREAMLINE")
|
||||||
|
h := string(mux.GetHistory())
|
||||||
|
if !strings.Contains(h, "PROXYLINE") {
|
||||||
|
t.Errorf("muxlog missing proxy line: %q", h)
|
||||||
|
}
|
||||||
|
if strings.Contains(h, "UPSTREAMLINE") {
|
||||||
|
t.Errorf("muxlog should not contain upstream line: %q", h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("both mode routes proxy and upstream into muxlog", func(t *testing.T) {
|
||||||
|
mux, proxy, upstream := NewLoggers(config.LogToStdoutBoth)
|
||||||
|
proxy.Info("PROXYLINE")
|
||||||
|
upstream.Info("UPSTREAMLINE")
|
||||||
|
h := string(mux.GetHistory())
|
||||||
|
if !strings.Contains(h, "PROXYLINE") || !strings.Contains(h, "UPSTREAMLINE") {
|
||||||
|
t.Errorf("muxlog history = %q", h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("none mode discards everything from muxlog", func(t *testing.T) {
|
||||||
|
mux, proxy, upstream := NewLoggers(config.LogToStdoutNone)
|
||||||
|
proxy.Info("PROXYLINE")
|
||||||
|
upstream.Info("UPSTREAMLINE")
|
||||||
|
if len(mux.GetHistory()) != 0 {
|
||||||
|
t.Errorf("muxlog should be empty, got %q", mux.GetHistory())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleLogs_Plain(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
s.muxlog.Write([]byte("a log line"))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d", w.Code)
|
||||||
|
}
|
||||||
|
if ct := w.Header().Get("Content-Type"); ct != "text/plain" {
|
||||||
|
t.Errorf("Content-Type = %q, want text/plain", ct)
|
||||||
|
}
|
||||||
|
if w.Body.String() != "a log line" {
|
||||||
|
t.Errorf("body = %q", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleLogs_HTMLRedirect(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||||
|
req.Header.Set("Accept", "text/html")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusFound {
|
||||||
|
t.Fatalf("status = %d, want 302", w.Code)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Location"); got != "/ui/" {
|
||||||
|
t.Errorf("Location = %q, want /ui/", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ClientIP(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*http.Request)
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"remote addr", func(r *http.Request) { r.RemoteAddr = "10.0.0.5:1234" }, "10.0.0.5"},
|
||||||
|
{"x-forwarded-for", func(r *http.Request) {
|
||||||
|
r.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8")
|
||||||
|
}, "1.2.3.4"},
|
||||||
|
{"x-real-ip", func(r *http.Request) { r.Header.Set("X-Real-IP", "9.9.9.9") }, "9.9.9.9"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.RemoteAddr = ""
|
||||||
|
c.setup(r)
|
||||||
|
if got := clientIP(r); got != c.want {
|
||||||
|
t.Errorf("clientIP() = %q, want %q", got, c.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_RequestLogMiddleware(t *testing.T) {
|
||||||
|
proxylog := logmon.NewWriter(io.Discard)
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
w.Write([]byte("hello"))
|
||||||
|
})
|
||||||
|
mw := CreateRequestLogMiddleware(proxylog)
|
||||||
|
|
||||||
|
t.Run("logs request", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
r.RemoteAddr = "192.168.1.1:5000"
|
||||||
|
mw(final).ServeHTTP(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
line := string(proxylog.GetHistory())
|
||||||
|
for _, want := range []string{"192.168.1.1", "POST /v1/chat/completions", "201", "5"} {
|
||||||
|
if !strings.Contains(line, want) {
|
||||||
|
t.Errorf("log line %q missing %q", line, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, path := range []string{"/wol-health", "/api/performance", "/metrics"} {
|
||||||
|
t.Run("skips "+path, func(t *testing.T) {
|
||||||
|
skipLog := logmon.NewWriter(io.Discard)
|
||||||
|
skipMW := CreateRequestLogMiddleware(skipLog)
|
||||||
|
skipMW(final).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, path, nil))
|
||||||
|
if len(skipLog.GetHistory()) != 0 {
|
||||||
|
t.Errorf("%s should not be logged; got %q", path, skipLog.GetHistory())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_RequestLogMiddleware_WebSocketUpgrade verifies that the access-log
|
||||||
|
// middleware (which wraps responses in statusRecorder) does not break websocket
|
||||||
|
// upgrades proxied through httputil.ReverseProxy. ReverseProxy requires the
|
||||||
|
// ResponseWriter to implement http.Hijacker to take over the connection; if
|
||||||
|
// statusRecorder does not forward Hijack, the upgrade is refused with 502.
|
||||||
|
func TestServer_RequestLogMiddleware_WebSocketUpgrade(t *testing.T) {
|
||||||
|
// Upstream: complete the upgrade handshake then echo bytes back. This
|
||||||
|
// stands in for an upstream that speaks websocket; ReverseProxy only cares
|
||||||
|
// about the 101 response and then copies raw bytes both ways.
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("upstream ResponseWriter is not an http.Hijacker")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("upstream hijack: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||||
|
brw.Flush()
|
||||||
|
// Echo whatever the client sends.
|
||||||
|
buf := make([]byte, 64)
|
||||||
|
n, err := brw.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
brw.Write(buf[:n])
|
||||||
|
brw.Flush()
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
upstreamURL, err := url.Parse(upstream.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse upstream URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Front server: ReverseProxy wrapped in the access-log middleware, which is
|
||||||
|
// the production statusRecorder-wrapped path.
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(upstreamURL)
|
||||||
|
mw := CreateRequestLogMiddleware(logmon.NewWriter(io.Discard))
|
||||||
|
front := httptest.NewServer(mw(proxy))
|
||||||
|
defer front.Close()
|
||||||
|
|
||||||
|
frontURL, err := url.Parse(front.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse front URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", frontURL.Host, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial front: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
req := "GET / HTTP/1.1\r\n" +
|
||||||
|
"Host: " + frontURL.Host + "\r\n" +
|
||||||
|
"Connection: Upgrade\r\n" +
|
||||||
|
"Upgrade: websocket\r\n" +
|
||||||
|
"\r\n"
|
||||||
|
if _, err := conn.Write([]byte(req)); err != nil {
|
||||||
|
t.Fatalf("write upgrade request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
br := bufio.NewReader(conn)
|
||||||
|
statusLine, err := br.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read status line: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(statusLine, "101") {
|
||||||
|
t.Fatalf("websocket upgrade failed: status line = %q, want 101 Switching Protocols", strings.TrimSpace(statusLine))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain the rest of the response headers.
|
||||||
|
for {
|
||||||
|
line, err := br.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read headers: %v", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(line) == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify bytes flow through the hijacked connection.
|
||||||
|
if _, err := conn.Write([]byte("ping")); err != nil {
|
||||||
|
t.Fatalf("write payload: %v", err)
|
||||||
|
}
|
||||||
|
echo := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(br, echo); err != nil {
|
||||||
|
t.Fatalf("read echo: %v", err)
|
||||||
|
}
|
||||||
|
if string(echo) != "ping" {
|
||||||
|
t.Errorf("echo = %q, want %q", echo, "ping")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,467 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
	"bufio"
	"bytes"
	"compress/flate"
	"compress/gzip"
	"compress/zlib"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/mostlygeek/llama-swap/internal/cache"
	"github.com/mostlygeek/llama-swap/internal/event"
	"github.com/mostlygeek/llama-swap/internal/logmon"
	"github.com/mostlygeek/llama-swap/internal/ring"
	"github.com/mostlygeek/llama-swap/internal/shared"
	"github.com/tidwall/gjson"
)
|
||||||
|
|
||||||
|
// TokenMetrics carries the token counts and throughput figures recorded for a
// single request.
type TokenMetrics struct {
	// CachedTokens is the cached-token count; note the JSON key is
	// "cache_tokens".
	CachedTokens int `json:"cache_tokens"`
	// InputTokens is the number of prompt/input tokens.
	InputTokens int `json:"input_tokens"`
	// OutputTokens is the number of generated/output tokens.
	OutputTokens int `json:"output_tokens"`
	// PromptPerSecond is the prompt-processing rate.
	PromptPerSecond float64 `json:"prompt_per_second"`
	// TokensPerSecond is the generation rate.
	TokensPerSecond float64 `json:"tokens_per_second"`
}
|
||||||
|
|
||||||
|
// ActivityLogEntry represents parsed token statistics from llama-server logs.
|
||||||
|
type ActivityLogEntry struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
ReqPath string `json:"req_path"`
|
||||||
|
RespContentType string `json:"resp_content_type"`
|
||||||
|
RespStatusCode int `json:"resp_status_code"`
|
||||||
|
Tokens TokenMetrics `json:"tokens"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
HasCapture bool `json:"has_capture"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActivityLogEvent carries a single activity log entry to event subscribers.
|
||||||
|
type ActivityLogEvent struct {
|
||||||
|
Metrics ActivityLogEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ActivityLogEvent) Type() uint32 {
|
||||||
|
return shared.ActivityLogEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
// metricsMonitor parses upstream responses for token statistics, keeps a
|
||||||
|
// bounded in-memory ring of recent activity, and (when captures are enabled)
|
||||||
|
// stores zstd+CBOR-compressed request/response captures in a sized cache.
|
||||||
|
type metricsMonitor struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
metrics ring.Buffer[ActivityLogEntry]
|
||||||
|
nextID int
|
||||||
|
logger *logmon.Monitor
|
||||||
|
|
||||||
|
enableCaptures bool
|
||||||
|
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMetricsMonitor creates a metricsMonitor retaining up to maxMetrics entries.
|
||||||
|
// captureBufferMB is the capture buffer size in megabytes; 0 disables captures.
|
||||||
|
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||||
|
if maxMetrics <= 0 {
|
||||||
|
maxMetrics = 1000
|
||||||
|
}
|
||||||
|
mm := &metricsMonitor{
|
||||||
|
logger: logger,
|
||||||
|
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
|
||||||
|
enableCaptures: captureBufferMB > 0,
|
||||||
|
}
|
||||||
|
if captureBufferMB > 0 {
|
||||||
|
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||||
|
}
|
||||||
|
return mm
|
||||||
|
}
|
||||||
|
|
||||||
|
// queueMetrics adds a metric to the ring and returns its assigned ID.
|
||||||
|
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
metric.ID = mp.nextID
|
||||||
|
mp.nextID++
|
||||||
|
mp.metrics.Push(metric)
|
||||||
|
return metric.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||||
|
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||||
|
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetrics returns a copy of the current metrics.
|
||||||
|
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
result := mp.metrics.Slice()
|
||||||
|
if result == nil {
|
||||||
|
return []ActivityLogEntry{}
|
||||||
|
}
|
||||||
|
if mp.captureCache != nil {
|
||||||
|
for i := range result {
|
||||||
|
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetricsJSON returns the current metrics as a JSON array.
|
||||||
|
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(mp.getMetrics())
|
||||||
|
}
|
||||||
|
|
||||||
|
// record parses a completed response body and stores/emits an activity entry.
|
||||||
|
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
||||||
|
// requests, with cf controlling which request/response parts are retained.
|
||||||
|
// reqBody and reqHeaders are the request data buffered before dispatch.
|
||||||
|
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||||
|
tm := ActivityLogEntry{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
ReqPath: r.URL.Path,
|
||||||
|
RespContentType: recorder.Header().Get("Content-Type"),
|
||||||
|
RespStatusCode: recorder.Status(),
|
||||||
|
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
queueAndEmit := func() {
|
||||||
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
mp.emitMetric(tm)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorder.Status() != http.StatusOK {
|
||||||
|
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||||
|
queueAndEmit()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||||
|
queueAndEmit()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||||
|
decoded, err := decompressBody(body, encoding)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||||
|
queueAndEmit()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body = decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||||
|
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||||
|
} else {
|
||||||
|
tm.Tokens = parsed.Tokens
|
||||||
|
tm.DurationMs = parsed.DurationMs
|
||||||
|
}
|
||||||
|
} else if gjson.ValidBytes(body) {
|
||||||
|
parsed := gjson.ParseBytes(body)
|
||||||
|
usage := parsed.Get("usage")
|
||||||
|
timings := parsed.Get("timings")
|
||||||
|
|
||||||
|
// /infill responses are arrays; timings live in the last element (#463).
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/infill") {
|
||||||
|
if arr := parsed.Array(); len(arr) > 0 {
|
||||||
|
timings = arr[len(arr)-1].Get("timings")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.Exists() || timings.Exists() {
|
||||||
|
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||||
|
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||||
|
} else {
|
||||||
|
tm.Tokens = parsedMetrics.Tokens
|
||||||
|
tm.DurationMs = parsedMetrics.DurationMs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", r.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
if mp.enableCaptures {
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: tm.ID,
|
||||||
|
ReqPath: r.URL.Path,
|
||||||
|
ReqHeaders: reqHeaders,
|
||||||
|
}
|
||||||
|
if cf&captureReqBody != 0 {
|
||||||
|
capture.ReqBody = reqBody
|
||||||
|
}
|
||||||
|
if cf&captureRespHeaders != 0 {
|
||||||
|
capture.RespHeaders = headerMap(recorder.Header())
|
||||||
|
redactHeaders(capture.RespHeaders)
|
||||||
|
delete(capture.RespHeaders, "Content-Encoding")
|
||||||
|
}
|
||||||
|
if cf&captureRespBody != 0 {
|
||||||
|
capture.RespBody = body
|
||||||
|
}
|
||||||
|
if mp.addCapture(capture) {
|
||||||
|
tm.HasCapture = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp.emitMetric(tm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// usagePaths enumerates the JSON paths, in lookup order, under which a
// streamed event may carry its usage object.
var usagePaths = []string{
	"usage",
	"response.usage",
	"message.usage",
}
|
||||||
|
|
||||||
|
// extractUsageTokens reads input/output/cached token counts from a usage
|
||||||
|
// gjson.Result, handling the field-name differences across endpoints.
|
||||||
|
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
|
||||||
|
cached = -1
|
||||||
|
if !usage.Exists() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := usage.Get("prompt_tokens"); v.Exists() {
|
||||||
|
input = v.Int()
|
||||||
|
ok = true
|
||||||
|
} else if v := usage.Get("input_tokens"); v.Exists() {
|
||||||
|
input = v.Int()
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := usage.Get("completion_tokens"); v.Exists() {
|
||||||
|
output = v.Int()
|
||||||
|
ok = true
|
||||||
|
} else if v := usage.Get("output_tokens"); v.Exists() {
|
||||||
|
output = v.Int()
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
|
||||||
|
cached = v.Int()
|
||||||
|
ok = true
|
||||||
|
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
|
||||||
|
cached = v.Int()
|
||||||
|
ok = true
|
||||||
|
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
|
||||||
|
cached = v.Int()
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||||
|
var (
|
||||||
|
inputTokens, outputTokens int64
|
||||||
|
cachedTokens int64 = -1
|
||||||
|
hasAny bool
|
||||||
|
timings gjson.Result
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix := []byte("data:")
|
||||||
|
for offset := 0; offset < len(body); {
|
||||||
|
nl := bytes.IndexByte(body[offset:], '\n')
|
||||||
|
var line []byte
|
||||||
|
if nl == -1 {
|
||||||
|
line = body[offset:]
|
||||||
|
offset = len(body)
|
||||||
|
} else {
|
||||||
|
line = body[offset : offset+nl]
|
||||||
|
offset += nl + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[len(prefix):])
|
||||||
|
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(data) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsed := gjson.ParseBytes(data)
|
||||||
|
|
||||||
|
for _, path := range usagePaths {
|
||||||
|
u := parsed.Get(path)
|
||||||
|
if !u.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i, o, c, ok := extractUsageTokens(u)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hasAny = true
|
||||||
|
if i > 0 {
|
||||||
|
inputTokens = i
|
||||||
|
}
|
||||||
|
if o > 0 {
|
||||||
|
outputTokens = o
|
||||||
|
}
|
||||||
|
if c >= 0 {
|
||||||
|
cachedTokens = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t := parsed.Get("timings"); t.Exists() {
|
||||||
|
timings = t
|
||||||
|
hasAny = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasAny {
|
||||||
|
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||||
|
input, output, cached, _ := extractUsageTokens(usage)
|
||||||
|
return buildMetrics(modelID, start, input, output, cached, timings), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
|
||||||
|
// optional llama-server timings (which override input/output and provide rates).
|
||||||
|
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
|
||||||
|
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||||
|
durationMs := wallDurationMs
|
||||||
|
tokensPerSecond := -1.0
|
||||||
|
promptPerSecond := -1.0
|
||||||
|
|
||||||
|
if timings.Exists() {
|
||||||
|
inputTokens = timings.Get("prompt_n").Int()
|
||||||
|
outputTokens = timings.Get("predicted_n").Int()
|
||||||
|
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||||
|
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||||
|
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||||
|
if timingsDurationMs > durationMs {
|
||||||
|
durationMs = timingsDurationMs
|
||||||
|
}
|
||||||
|
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||||
|
cachedTokens = cachedValue.Int()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ActivityLogEntry{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
Tokens: TokenMetrics{
|
||||||
|
CachedTokens: int(cachedTokens),
|
||||||
|
InputTokens: int(inputTokens),
|
||||||
|
OutputTokens: int(outputTokens),
|
||||||
|
PromptPerSecond: promptPerSecond,
|
||||||
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
},
|
||||||
|
DurationMs: durationMs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressBody decodes body according to a Content-Encoding header value.
//
// Supported codings (case-insensitive, surrounding whitespace ignored):
//   - "gzip": decoded with compress/gzip.
//   - "deflate": HTTP defines this coding as the zlib container (RFC 9110,
//     section 8.4.1.2), but some servers send a raw DEFLATE stream instead.
//     The zlib form is tried first and raw DEFLATE is used as a fallback, so
//     both variants decode.
//
// Any other coding returns body unchanged with a nil error. An error is
// returned when the data cannot be decoded with the declared coding.
func decompressBody(body []byte, encoding string) ([]byte, error) {
	switch strings.ToLower(strings.TrimSpace(encoding)) {
	case "gzip":
		reader, err := gzip.NewReader(bytes.NewReader(body))
		if err != nil {
			return nil, err
		}
		defer reader.Close()
		return io.ReadAll(reader)

	case "deflate":
		// Standards-compliant form: zlib header + DEFLATE data + Adler-32.
		if zr, err := zlib.NewReader(bytes.NewReader(body)); err == nil {
			decoded, readErr := io.ReadAll(zr)
			zr.Close()
			if readErr == nil {
				return decoded, nil
			}
			// A raw stream can pass the zlib header check by coincidence;
			// fall through and retry it as raw DEFLATE.
		}
		reader := flate.NewReader(bytes.NewReader(body))
		defer reader.Close()
		return io.ReadAll(reader)

	default:
		return body, nil
	}
}
|
||||||
|
|
||||||
|
// filterAcceptEncoding filters Accept-Encoding to only gzip/deflate so response
|
||||||
|
// bodies remain decompressible for metrics parsing.
|
||||||
|
func filterAcceptEncoding(acceptEncoding string) string {
|
||||||
|
if acceptEncoding == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||||
|
var filtered []string
|
||||||
|
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||||
|
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||||
|
if supported[strings.ToLower(encoding)] {
|
||||||
|
filtered = append(filtered, strings.TrimSpace(part))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(filtered, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseBodyCopier tees the upstream response to the client while buffering
// it for metrics parsing. Status defaults to 200 until WriteHeader is called
// with a final status.
type responseBodyCopier struct {
	http.ResponseWriter
	body        *bytes.Buffer // copy of everything written through tee
	tee         io.Writer     // writes to the client first, then to body
	status      int           // final status code of the response
	wroteHeader bool          // true once the final status has been sent
	start       time.Time     // when the copier was created
}

// newBodyCopier wraps w, starting the request clock and an empty body buffer.
func newBodyCopier(w http.ResponseWriter) *responseBodyCopier {
	buf := &bytes.Buffer{}
	return &responseBodyCopier{
		ResponseWriter: w,
		body:           buf,
		tee:            io.MultiWriter(w, buf),
		status:         http.StatusOK,
		start:          time.Now(),
	}
}

// Write sends b to the client and, for ordinary responses, also appends it to
// the body buffer. A write without a prior WriteHeader implies 200.
func (w *responseBodyCopier) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	// On a protocol upgrade (e.g. websocket) the body is raw framed data, not a
	// metrics-parseable response, so write straight to the client without
	// buffering a copy we can't use.
	if w.status == http.StatusSwitchingProtocols {
		return w.ResponseWriter.Write(b)
	}
	return w.tee.Write(b)
}

// WriteHeader forwards the first final status to the client and records it;
// later calls are ignored.
//
// Informational 1xx responses (other than 101 Switching Protocols) may be
// sent any number of times before the final status, so they are forwarded
// without being recorded. Latching on them would swallow the real status and
// leave Status reporting the 1xx code.
func (w *responseBodyCopier) WriteHeader(statusCode int) {
	if w.wroteHeader {
		return
	}
	if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
		w.ResponseWriter.WriteHeader(statusCode)
		return
	}
	w.wroteHeader = true
	w.status = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
}

// Flush forwards to the underlying writer so streaming responses still flush.
// It is a no-op when the underlying writer cannot flush.
func (w *responseBodyCopier) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Hijack forwards to the underlying writer so httputil.ReverseProxy can take
// over the connection for websocket upgrades. It returns an error when the
// underlying writer is not an http.Hijacker.
func (w *responseBodyCopier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
		return hj.Hijack()
	}
	return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
}

// Unwrap returns the wrapped ResponseWriter so http.ResponseController can
// reach optional features of the original writer.
func (w *responseBodyCopier) Unwrap() http.ResponseWriter { return w.ResponseWriter }

// Status returns the recorded final status code (200 until one is written).
func (w *responseBodyCopier) Status() int { return w.status }

// StartTime returns the time at which the copier was created.
func (w *responseBodyCopier) StartTime() time.Time { return w.start }
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||||
|
// model-dispatched POST requests. It resolves the model, tees the response into
|
||||||
|
// a buffer, and parses token usage once the upstream handler returns.
|
||||||
|
func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if mm == nil || r.Method != http.MethodPost {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve the model now so downstream dispatch hits the context
|
||||||
|
// fast path; FetchContext restores the request body.
|
||||||
|
data, err := router.FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
router.SendError(w, r, router.ErrNoModelInContext)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer the request body/headers for capture before dispatch
|
||||||
|
// consumes them.
|
||||||
|
cf := captureFieldsFor(r.URL.Path)
|
||||||
|
var reqBody []byte
|
||||||
|
var reqHeaders map[string]string
|
||||||
|
if mm.enableCaptures {
|
||||||
|
if cf&captureReqBody != 0 && r.Body != nil {
|
||||||
|
if buffered, err := io.ReadAll(r.Body); err == nil {
|
||||||
|
reqBody = buffered
|
||||||
|
r.Body.Close()
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(reqBody))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cf&captureReqHeaders != 0 {
|
||||||
|
reqHeaders = headerMap(r.Header)
|
||||||
|
redactHeaders(reqHeaders)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restrict Accept-Encoding to encodings we can decompress so the
|
||||||
|
// buffered response body stays parseable.
|
||||||
|
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
|
||||||
|
r.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := newBodyCopier(w)
|
||||||
|
next.ServeHTTP(recorder, r)
|
||||||
|
mm.record(data.ModelID, r, recorder, cf, reqBody, reqHeaders)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_ParseMetrics_ChatCompletions(t *testing.T) {
|
||||||
|
body := `{"usage":{"prompt_tokens":12,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}`
|
||||||
|
parsed := gjson.Parse(body)
|
||||||
|
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseMetrics: %v", err)
|
||||||
|
}
|
||||||
|
if entry.Tokens.InputTokens != 12 || entry.Tokens.OutputTokens != 7 || entry.Tokens.CachedTokens != 4 {
|
||||||
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ParseMetrics_Timings(t *testing.T) {
|
||||||
|
body := `{"timings":{"prompt_n":20,"predicted_n":50,"prompt_per_second":100.0,"predicted_per_second":40.0,"prompt_ms":200,"predicted_ms":1250,"cache_n":8}}`
|
||||||
|
parsed := gjson.Parse(body)
|
||||||
|
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseMetrics: %v", err)
|
||||||
|
}
|
||||||
|
if entry.Tokens.InputTokens != 20 || entry.Tokens.OutputTokens != 50 || entry.Tokens.CachedTokens != 8 {
|
||||||
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
|
}
|
||||||
|
if entry.Tokens.TokensPerSecond != 40.0 || entry.Tokens.PromptPerSecond != 100.0 {
|
||||||
|
t.Fatalf("rates = %+v", entry.Tokens)
|
||||||
|
}
|
||||||
|
if entry.DurationMs != 1450 {
|
||||||
|
t.Fatalf("DurationMs = %d, want 1450", entry.DurationMs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ProcessStreamingResponse(t *testing.T) {
|
||||||
|
body := []byte("data: {\"choices\":[{}]}\n\n" +
|
||||||
|
"data: {\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":33}}\n\n" +
|
||||||
|
"data: [DONE]\n\n")
|
||||||
|
entry, err := processStreamingResponse("m", time.Now(), body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("processStreamingResponse: %v", err)
|
||||||
|
}
|
||||||
|
if entry.Tokens.InputTokens != 15 || entry.Tokens.OutputTokens != 33 {
|
||||||
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
|
||||||
|
if _, err := processStreamingResponse("m", time.Now(), []byte("data: [DONE]\n\n")); err == nil {
|
||||||
|
t.Fatal("expected error for stream with no usage data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||||
|
// /infill responses are arrays; timings live in the last element.
|
||||||
|
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||||
|
parsed := gjson.Parse(body)
|
||||||
|
timings := parsed.Get("timings")
|
||||||
|
if arr := parsed.Array(); len(arr) > 0 {
|
||||||
|
timings = arr[len(arr)-1].Get("timings")
|
||||||
|
}
|
||||||
|
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), timings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseMetrics: %v", err)
|
||||||
|
}
|
||||||
|
if entry.Tokens.InputTokens != 5 || entry.Tokens.OutputTokens != 9 {
|
||||||
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,290 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||||
|
// dispatch. It supersedes router.Server: it builds the local and peer routers
|
||||||
|
// directly and dispatches between them itself.
|
||||||
|
type Server struct {
|
||||||
|
cfg config.Config
|
||||||
|
|
||||||
|
muxlog *logmon.Monitor
|
||||||
|
proxylog *logmon.Monitor
|
||||||
|
upstreamlog *logmon.Monitor
|
||||||
|
|
||||||
|
perf *perf.Monitor
|
||||||
|
inflight *inflightCounter
|
||||||
|
metrics *metricsMonitor
|
||||||
|
build BuildInfo
|
||||||
|
|
||||||
|
local router.LocalRouter
|
||||||
|
peer router.Router
|
||||||
|
|
||||||
|
mux *http.ServeMux
|
||||||
|
handler http.Handler
|
||||||
|
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownFn context.CancelFunc
|
||||||
|
shuttingDown atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route tables for model-dispatched endpoints, grouped by how the model id
// reaches the server.
var (
	// modelPostJSONRoutes are endpoints with a model id in the JSON request body.
	modelPostJSONRoutes = []string{
		// OpenAI / Anthropic style endpoints.
		"/v1/chat/completions",
		"/v1/responses",
		"/v1/completions",
		"/v1/messages",
		"/v1/messages/count_tokens",
		"/v1/embeddings",

		// Reranking, with and without the /v1 prefix.
		"/reranking",
		"/rerank",
		"/v1/rerank",
		"/v1/reranking",

		// Unprefixed completion endpoints.
		"/infill",
		"/completion",

		// Audio and image generation.
		"/v1/audio/speech",
		"/v1/audio/voices",
		"/v1/images/generations",
		"/sdapi/v1/txt2img",
		"/sdapi/v1/img2img",

		// versionless routes, the /v/ is stripped before the request is forwarded upstream
		// see issue #728
		"/v/chat/completions",
		"/v/responses",
		"/v/completions",
		"/v/messages",
		"/v/messages/count_tokens",
		"/v/embeddings",
		"/v/rerank",
		"/v/reranking",
	}

	// modelPostFormRoutes are multipart/form-data endpoints with a model id in the form data
	modelPostFormRoutes = []string{
		"/v1/audio/transcriptions",
		"/v1/images/edits",
	}

	// modelGetRoutes are model-dispatched GET endpoints (the model arrives as a
	// query parameter).
	modelGetRoutes = []string{
		"/v1/audio/voices",
		"/sdapi/v1/loras",
	}
)
|
||||||
|
|
||||||
|
// BuildInfo carries version metadata surfaced by GET /api/version.
type BuildInfo struct {
	Version string // version string of the build
	Commit  string // commit the build was made from
	Date    string // build date, stored as an opaque string
}
|
||||||
|
|
||||||
|
func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, upstreamlog *logmon.Monitor, perfMon *perf.Monitor, build BuildInfo) (*Server, error) {
|
||||||
|
var local router.LocalRouter
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if cfg.Matrix != nil {
|
||||||
|
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating matrix router: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating group router: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := router.NewPeer(cfg, proxylog)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating peer router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
muxlog: muxlog,
|
||||||
|
proxylog: proxylog,
|
||||||
|
upstreamlog: upstreamlog,
|
||||||
|
perf: perfMon,
|
||||||
|
inflight: &inflightCounter{},
|
||||||
|
metrics: newMetricsMonitor(proxylog, cfg.MetricsMaxInMemory, cfg.CaptureBuffer),
|
||||||
|
build: build,
|
||||||
|
local: local,
|
||||||
|
peer: peer,
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownFn: shutdownFn,
|
||||||
|
}
|
||||||
|
s.routes()
|
||||||
|
s.startPreload()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||||
|
// router. The model is resolved once via router.FetchContext.
|
||||||
|
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
stripVersionPrefix(r)
|
||||||
|
|
||||||
|
data, err := router.FetchContext(r, s.cfg)
|
||||||
|
if err != nil {
|
||||||
|
router.SendError(w, r, router.ErrNoModelInContext)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case s.local.Handles(data.ModelID):
|
||||||
|
s.proxylog.Debugf("dispatch: using local process for model: %s", data.ModelID)
|
||||||
|
s.local.ServeHTTP(w, r)
|
||||||
|
case s.peer.Handles(data.ModelID):
|
||||||
|
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||||
|
s.peer.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
router.SendError(w, r, router.ErrNoRouterFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripVersionPrefix drops the leading "/v" from versionless /v/... request
// paths, so "/v/chat/completions" is forwarded upstream as
// "/chat/completions" (issue #728). Other paths, including "/v1/...", are
// left untouched. Only r.URL.Path is modified.
func stripVersionPrefix(r *http.Request) {
	rest, versionless := strings.CutPrefix(r.URL.Path, "/v/")
	if !versionless {
		return
	}
	r.URL.Path = "/" + rest
}
|
||||||
|
|
||||||
|
// routes builds the mux, registers every route, and wraps the mux with the
|
||||||
|
// global CORS middleware.
|
||||||
|
func (s *Server) routes() {
|
||||||
|
authMW := CreateAuthMiddleware(s.cfg)
|
||||||
|
filterMW := CreateFilterMiddleware(s.cfg)
|
||||||
|
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
||||||
|
|
||||||
|
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
||||||
|
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
||||||
|
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
||||||
|
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
||||||
|
// other's Content-Type. Both run before the metrics middleware so it buffers
|
||||||
|
// the rewritten body.
|
||||||
|
modelChain := chain.New(
|
||||||
|
authMW,
|
||||||
|
CreateConcurrencyMiddleware(s.cfg),
|
||||||
|
filterMW,
|
||||||
|
formFilterMW,
|
||||||
|
CreateInflightMiddleware(s.inflight),
|
||||||
|
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||||
|
)
|
||||||
|
// Custom endpoints only need auth.
|
||||||
|
apiChain := chain.New(authMW)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
dispatch := http.HandlerFunc(s.localPeerHandler)
|
||||||
|
|
||||||
|
for _, path := range modelPostJSONRoutes {
|
||||||
|
mux.Handle("POST "+path, modelChain.Then(dispatch))
|
||||||
|
}
|
||||||
|
for _, path := range modelPostFormRoutes {
|
||||||
|
mux.Handle("POST "+path, modelChain.Then(dispatch))
|
||||||
|
}
|
||||||
|
for _, path := range modelGetRoutes {
|
||||||
|
mux.Handle("GET "+path, modelChain.Then(dispatch))
|
||||||
|
}
|
||||||
|
|
||||||
|
// llama-swap API + custom endpoints.
|
||||||
|
mux.Handle("GET /v1/models", apiChain.ThenFunc(s.handleListModels))
|
||||||
|
mux.Handle("GET /logs", apiChain.ThenFunc(s.handleLogs))
|
||||||
|
mux.Handle("GET /logs/stream", apiChain.ThenFunc(s.handleLogStream))
|
||||||
|
mux.Handle("GET /logs/stream/{logMonitorID...}", apiChain.ThenFunc(s.handleLogStream))
|
||||||
|
|
||||||
|
mux.HandleFunc("GET /health", handleHealth)
|
||||||
|
mux.HandleFunc("GET /wol-health", handleHealth)
|
||||||
|
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||||
|
|
||||||
|
// Embedded UI.
|
||||||
|
mux.HandleFunc("GET /ui/", s.handleUI)
|
||||||
|
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||||
|
|
||||||
|
// Prometheus metrics (no auth, matches the legacy endpoint).
|
||||||
|
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
||||||
|
|
||||||
|
// Operations endpoints.
|
||||||
|
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||||
|
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||||
|
|
||||||
|
// Upstream passthrough.
|
||||||
|
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||||
|
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
||||||
|
|
||||||
|
// API group (API-key protected) consumed by the UI.
|
||||||
|
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||||
|
mux.Handle("POST /api/models/unload/{model...}", apiChain.ThenFunc(s.handleAPIUnloadModel))
|
||||||
|
mux.Handle("GET /api/events", apiChain.ThenFunc(s.handleAPIEvents))
|
||||||
|
mux.Handle("GET /api/metrics", apiChain.ThenFunc(s.handleAPIMetrics))
|
||||||
|
mux.Handle("GET /api/performance", apiChain.ThenFunc(s.handleAPIPerformance))
|
||||||
|
mux.Handle("GET /api/version", apiChain.ThenFunc(s.handleAPIVersion))
|
||||||
|
mux.Handle("GET /api/captures/{id}", apiChain.ThenFunc(s.handleAPICapture))
|
||||||
|
|
||||||
|
s.mux = mux
|
||||||
|
s.handler = chain.New(CreateRequestLogMiddleware(s.proxylog), CreateCORSMiddleware()).Then(mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseStreams cancels long-lived response streams (Server-Sent Events) so a
|
||||||
|
// graceful httpServer.Shutdown can drain without blocking on them. It does not
|
||||||
|
// tear down routers; call Shutdown for that. Safe to call repeatedly.
|
||||||
|
func (s *Server) CloseStreams() {
|
||||||
|
s.shutdownFn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown stops the local and peer routers in parallel. It is idempotent;
|
||||||
|
// repeated calls return nil without re-running shutdown.
|
||||||
|
//
|
||||||
|
// Callers must drain inflight HTTP requests (httpServer.Shutdown) before
|
||||||
|
// calling this, otherwise inflight requests 502 when their processes are torn
|
||||||
|
// down. Call CloseStreams before httpServer.Shutdown so SSE streams do not
|
||||||
|
// block the drain.
|
||||||
|
func (s *Server) Shutdown(timeout time.Duration) error {
|
||||||
|
if !s.shuttingDown.CompareAndSwap(false, true) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.shutdownFn()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
for _, rt := range []router.Router{s.local, s.peer} {
|
||||||
|
if rt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(rt router.Router) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := rt.Shutdown(timeout); err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
errs = append(errs, err)
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
@@ -0,0 +1,331 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubRouter is a minimal router.LocalRouter for Server dispatch tests.
|
||||||
|
type stubRouter struct {
|
||||||
|
models map[string]bool
|
||||||
|
response string
|
||||||
|
shutdownCalls atomic.Int32
|
||||||
|
running map[string]process.ProcessState
|
||||||
|
unloadCalls atomic.Int32
|
||||||
|
loggers map[string]*logmon.Monitor
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStubRouter(models []string, response string) *stubRouter {
|
||||||
|
m := make(map[string]bool, len(models))
|
||||||
|
for _, id := range models {
|
||||||
|
m[id] = true
|
||||||
|
}
|
||||||
|
return &stubRouter{models: m, response: response}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubRouter) Handles(model string) bool { return s.models[model] }
|
||||||
|
func (s *stubRouter) Shutdown(_ time.Duration) error { s.shutdownCalls.Add(1); return nil }
|
||||||
|
func (s *stubRouter) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(s.response))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubRouter) RunningModels() map[string]process.ProcessState { return s.running }
|
||||||
|
func (s *stubRouter) Unload(_ time.Duration, _ ...string) { s.unloadCalls.Add(1) }
|
||||||
|
func (s *stubRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
|
||||||
|
if s.loggers != nil {
|
||||||
|
if lg, ok := s.loggers[modelID]; ok {
|
||||||
|
return lg, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestServer wires a Server with stub routers and a built mux.
|
||||||
|
func newTestServer(local router.LocalRouter, peer router.Router) *Server {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
proxylog := logmon.NewWriter(io.Discard)
|
||||||
|
s := &Server{
|
||||||
|
cfg: config.Config{},
|
||||||
|
muxlog: logmon.NewWriter(io.Discard),
|
||||||
|
proxylog: proxylog,
|
||||||
|
upstreamlog: logmon.NewWriter(io.Discard),
|
||||||
|
inflight: &inflightCounter{},
|
||||||
|
metrics: newMetricsMonitor(proxylog, 0, 0),
|
||||||
|
local: local,
|
||||||
|
peer: peer,
|
||||||
|
shutdownCtx: ctx,
|
||||||
|
shutdownFn: cancel,
|
||||||
|
}
|
||||||
|
s.routes()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatRequest builds a POST /v1/chat/completions request whose JSON body
// names model. The model string is inserted verbatim, without escaping.
func chatRequest(model string) *http.Request {
	payload := `{"model":"` + model + `"}`
	r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(payload))
	r.Header.Set("Content-Type", "application/json")
	return r
}
|
||||||
|
|
||||||
|
func TestServer_New_GroupConfig(t *testing.T) {
|
||||||
|
discard := logmon.NewWriter(io.Discard)
|
||||||
|
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New (group): %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Shutdown(time.Second); err != nil {
|
||||||
|
t.Fatalf("Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_New_MatrixConfig(t *testing.T) {
|
||||||
|
discard := logmon.NewWriter(io.Discard)
|
||||||
|
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
|
||||||
|
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New (matrix): %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Shutdown(time.Second); err != nil {
|
||||||
|
t.Fatalf("Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_RouteToLocalModel(t *testing.T) {
|
||||||
|
s := newTestServer(
|
||||||
|
newStubRouter([]string{"local-model"}, "local response"),
|
||||||
|
newStubRouter(nil, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, chatRequest("local-model"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if w.Body.String() != "local response" {
|
||||||
|
t.Errorf("body=%q want %q", w.Body.String(), "local response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_RouteToPeerModel(t *testing.T) {
|
||||||
|
s := newTestServer(
|
||||||
|
newStubRouter(nil, ""),
|
||||||
|
newStubRouter([]string{"peer-model"}, "peer response"),
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, chatRequest("peer-model"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if w.Body.String() != "peer response" {
|
||||||
|
t.Errorf("body=%q want %q", w.Body.String(), "peer response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_UnknownModelReturns404(t *testing.T) {
|
||||||
|
s := newTestServer(
|
||||||
|
newStubRouter([]string{"local-model"}, ""),
|
||||||
|
newStubRouter(nil, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, chatRequest("unknown-model"))
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status=%d want 404 body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_UnknownPathReturns404(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/does-not-exist", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status=%d want 404", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Health(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
for _, path := range []string{"/health", "/wol-health"} {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "OK" {
|
||||||
|
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_CORSPreflight(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status=%d want 204", w.Code)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||||
|
t.Errorf("Access-Control-Allow-Origin=%q want *", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Unload(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/unload", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "OK" {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := local.unloadCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("unloadCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Running(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "")
|
||||||
|
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{Models: map[string]config.ModelConfig{
|
||||||
|
"m1": {
|
||||||
|
Cmd: "llama-server",
|
||||||
|
Proxy: "http://localhost:9999",
|
||||||
|
UnloadAfter: 300,
|
||||||
|
Name: "Model One",
|
||||||
|
Description: "the first model",
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/running", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Running []runningModel `json:"running"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v body=%q", err, w.Body.String())
|
||||||
|
}
|
||||||
|
if len(resp.Running) != 1 {
|
||||||
|
t.Fatalf("running=%v want 1 entry", resp.Running)
|
||||||
|
}
|
||||||
|
want := runningModel{
|
||||||
|
Model: "m1",
|
||||||
|
State: "ready",
|
||||||
|
Cmd: "llama-server",
|
||||||
|
Proxy: "http://localhost:9999",
|
||||||
|
TTL: 300,
|
||||||
|
Name: "Model One",
|
||||||
|
Description: "the first model",
|
||||||
|
}
|
||||||
|
if resp.Running[0] != want {
|
||||||
|
t.Errorf("got %+v want %+v", resp.Running[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Preload(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "ok")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{Hooks: config.HooksConfig{
|
||||||
|
OnStartup: config.HookOnStartup{Preload: []string{"m1"}},
|
||||||
|
}}
|
||||||
|
|
||||||
|
got := make(chan shared.ModelPreloadedEvent, 1)
|
||||||
|
cancel := event.On(func(e shared.ModelPreloadedEvent) { got <- e })
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s.startPreload()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case e := <-got:
|
||||||
|
if e.ModelName != "m1" || !e.Success {
|
||||||
|
t.Errorf("event=%+v want {ModelName:m1 Success:true}", e)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("preload event not received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Shutdown_StopsRoutersAndIsIdempotent(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"local-model"}, "")
|
||||||
|
peer := newStubRouter(nil, "")
|
||||||
|
s := newTestServer(local, peer)
|
||||||
|
|
||||||
|
if err := s.Shutdown(time.Second); err != nil {
|
||||||
|
t.Fatalf("Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Shutdown(time.Second); err != nil {
|
||||||
|
t.Fatalf("second Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
if got := local.shutdownCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("local shutdownCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := peer.shutdownCalls.Load(); got != 1 {
|
||||||
|
t.Errorf("peer shutdownCalls=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_LogStream_ModelID(t *testing.T) {
|
||||||
|
buf := logmon.NewWriter(io.Discard)
|
||||||
|
buf.Write([]byte("hello from model"))
|
||||||
|
|
||||||
|
local := newStubRouter([]string{"mymodel"}, "")
|
||||||
|
local.loggers = map[string]*logmon.Monitor{"mymodel": buf}
|
||||||
|
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{Models: map[string]config.ModelConfig{"mymodel": {}}}
|
||||||
|
|
||||||
|
// Pre-cancel the context so the streaming loop exits immediately after
|
||||||
|
// flushing history.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/logs/stream/mymodel", nil).WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if got := w.Body.String(); got != "hello from model" {
|
||||||
|
t.Errorf("body=%q want %q", got, "hello from model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_LogStream_UnknownID_Returns400(t *testing.T) {
|
||||||
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs/stream/no-such-model", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("status=%d want 400", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// uiStaticFS holds the embedded UI build. The build is copied into ui_dist by
|
||||||
|
// the Makefile's `ui` target; placeholder.txt keeps the embed valid before a
|
||||||
|
// build has run.
|
||||||
|
//
|
||||||
|
//go:embed ui_dist
|
||||||
|
var uiStaticFS embed.FS
|
||||||
|
|
||||||
|
// uiFS is the embedded UI rooted at ui_dist.
|
||||||
|
var uiFS = func() http.FileSystem {
|
||||||
|
sub, err := fs.Sub(uiStaticFS, "ui_dist")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return http.FS(sub)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// selectEncoding chooses the best pre-compressed encoding the client accepts.
// It returns the encoding ("br" or "gzip") and the matching file extension,
// or two empty strings when neither is acceptable.
//
// Brotli is preferred over gzip regardless of their order or relative weight
// in the header. Coding names are matched case-insensitively, and an entry
// with a zero quality value (e.g. "br;q=0") is treated as an explicit refusal
// of that coding.
func selectEncoding(acceptEncoding string) (encoding, ext string) {
	if acceptEncoding == "" {
		return "", ""
	}

	gzipAccepted := false
	for _, part := range strings.Split(acceptEncoding, ",") {
		name, params, _ := strings.Cut(part, ";")
		name = strings.TrimSpace(name)
		if qvalueIsZero(params) {
			continue
		}
		switch {
		case strings.EqualFold(name, "br"):
			// Best possible choice; no need to look further.
			return "br", ".br"
		case strings.EqualFold(name, "gzip"):
			// Remember it, but keep scanning in case br appears later.
			gzipAccepted = true
		}
	}

	if gzipAccepted {
		return "gzip", ".gz"
	}
	return "", ""
}

// qvalueIsZero reports whether params, the ";"-separated parameter list that
// follows a coding name, carries a q parameter equal to zero ("0", "0.0",
// ...). A missing or non-zero q yields false.
func qvalueIsZero(params string) bool {
	for _, p := range strings.Split(params, ";") {
		key, val, ok := strings.Cut(p, "=")
		if !ok || !strings.EqualFold(strings.TrimSpace(key), "q") {
			continue
		}
		val = strings.TrimSpace(val)
		// Zero is written as "0" optionally followed by "." and more zeros.
		return strings.HasPrefix(val, "0") && strings.Trim(val, "0.") == ""
	}
	return false
}
|
||||||
|
|
||||||
|
// serveCompressedFile serves name from fsys, preferring a pre-compressed
|
||||||
|
// sibling (name+".br" / name+".gz") when the client accepts it. It returns an
|
||||||
|
// error without writing a response when name cannot be served, so callers can
|
||||||
|
// fall back (e.g. SPA routing).
|
||||||
|
func serveCompressedFile(fsys http.FileSystem, w http.ResponseWriter, r *http.Request, name string) error {
|
||||||
|
if encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding")); encoding != "" {
|
||||||
|
if cf, err := fsys.Open(name + ext); err == nil {
|
||||||
|
defer cf.Close()
|
||||||
|
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||||
|
w.Header().Set("Content-Encoding", encoding)
|
||||||
|
w.Header().Add("Vary", "Accept-Encoding")
|
||||||
|
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := fsys.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
stat, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if stat.IsDir() {
|
||||||
|
return fs.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUI serves the embedded SPA under /ui/.
|
||||||
|
func (s *Server) handleUI(w http.ResponseWriter, r *http.Request) {
|
||||||
|
serveUI(uiFS, w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveUI serves the SPA from fsys. Real files are served with compression
|
||||||
|
// support; unknown paths without a file extension fall back to index.html so
|
||||||
|
// client-side routing works.
|
||||||
|
func serveUI(fsys http.FileSystem, w http.ResponseWriter, r *http.Request) {
|
||||||
|
name := strings.TrimPrefix(r.URL.Path, "/ui/")
|
||||||
|
if name == "" {
|
||||||
|
name = "index.html"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := serveCompressedFile(fsys, w, r, name); err != nil {
|
||||||
|
if strings.Contains(path.Base(name), ".") {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := serveCompressedFile(fsys, w, r, "index.html"); err != nil {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFavicon serves /favicon.ico from the embedded UI build.
|
||||||
|
func (s *Server) handleFavicon(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := serveCompressedFile(uiFS, w, r, "favicon.ico"); err != nil {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer_SelectEncoding(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
accept string
|
||||||
|
encoding string
|
||||||
|
ext string
|
||||||
|
}{
|
||||||
|
{"", "", ""},
|
||||||
|
{"gzip", "gzip", ".gz"},
|
||||||
|
{"gzip, deflate, br", "br", ".br"},
|
||||||
|
{"deflate", "", ""},
|
||||||
|
{"br;q=1.0, gzip;q=0.8", "br", ".br"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
enc, ext := selectEncoding(c.accept)
|
||||||
|
if enc != c.encoding || ext != c.ext {
|
||||||
|
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)", c.accept, enc, ext, c.encoding, c.ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// uiTestFS builds an in-memory stand-in for the UI build: an index page, a
// script with brotli and gzip siblings, and a favicon.
func uiTestFS() http.FileSystem {
	contents := map[string]string{
		"index.html":  "<html>app</html>",
		"app.js":      "plain",
		"app.js.br":   "brotli",
		"app.js.gz":   "gzipped",
		"favicon.ico": "icon",
	}

	files := make(fstest.MapFS, len(contents))
	for name, body := range contents {
		files[name] = &fstest.MapFile{Data: []byte(body)}
	}
	return http.FS(files)
}
|
||||||
|
|
||||||
|
func serveUIRequest(t *testing.T, path, acceptEncoding string) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||||
|
if acceptEncoding != "" {
|
||||||
|
req.Header.Set("Accept-Encoding", acceptEncoding)
|
||||||
|
}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
serveUI(uiTestFS(), w, req)
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ServeUI_File(t *testing.T) {
|
||||||
|
w := serveUIRequest(t, "/ui/app.js", "")
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
if w.Body.String() != "plain" {
|
||||||
|
t.Errorf("body = %q, want plain", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ServeUI_Brotli(t *testing.T) {
|
||||||
|
w := serveUIRequest(t, "/ui/app.js", "gzip, br")
|
||||||
|
if got := w.Header().Get("Content-Encoding"); got != "br" {
|
||||||
|
t.Fatalf("Content-Encoding = %q, want br", got)
|
||||||
|
}
|
||||||
|
if w.Body.String() != "brotli" {
|
||||||
|
t.Errorf("body = %q, want brotli", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ServeUI_IndexAndRoot(t *testing.T) {
|
||||||
|
for _, path := range []string{"/ui/", "/ui/index.html"} {
|
||||||
|
w := serveUIRequest(t, path, "")
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
|
||||||
|
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ServeUI_SPAFallback(t *testing.T) {
|
||||||
|
w := serveUIRequest(t, "/ui/models", "")
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
|
||||||
|
t.Errorf("SPA fallback: status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ServeUI_MissingFile(t *testing.T) {
|
||||||
|
w := serveUIRequest(t, "/ui/missing.js", "")
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user