Introduce new routing backend (#790)
This is a huge backend change that essentially started with rewriting the concurrency handling for processes and blew up to a refactor of the entire application. In short these are the improvements: **Better state and life cycle management:** Life cycle management of processes has always been the trickiest part of the code. Juggling mutex locks between multiple locations to reduce race conditions was complex. Too complex for my feeble brain to build a simple mental model around as llama-swap gained more features. All of that has been refactored. Most of the locks are gone, replaced with a single run() that owns all state changes. There is one place to start from now to understand and extend routing logic. The improved life cycle management makes it easier to implement more complex swap optimization strategies in the future like #727. **Collation of requests:** llama-swap previously handled requests and swapping in the order they came in. For example requests for models in this order ABCABC would result in 5 swaps. Now those requests are handled in this order AABBCC. The result is less time waiting for swap under a high churn request queue. This fixes #588 #612. A possible future enhancement is to support a starvation parameter so swap can be forced when models have been waiting too long. **Shared base implementation for groups and swap matrix:** During the refactor it became clear that much of the swapping logic was shared between these two implementations. That is not surprising considering the swap matrix was added many moons after groups. Now they share a common base and their specific swap strategies are implemented into the swapPlanner interface. Requests for bespoke or specific swapping scenarios is a common theme in the issues. Now users can implement whatever bespoke and weird swapping strategy they want in their own fork. Just ask your agent of choice to implement swapPlanner. I'll still remaining more conservative on what actually lands in core llama-swap and will continue to evaluate PRs if the changes is good for everyone or just one specific use case. **AI / Agentic Disclosure:** I paid very close attention to the low level swap concurrency design and implementation. It's important to keep that essential part reliable, boring and no surprises. Backwards compatibility was also maintained, even the one way non-exclusive group model loading behaviour that people have rightly pointed out be a weird design decision. With the underlying swap core done the web server, api and UI sitting on top were largely ported over with Claude Code and Opus 4.7 in multiple phases. If you're curious I kept the changes in docs/newrouter-todo.md. I did several passes to make sure things weren't left behind. However, even frontier LLMs at the time of this PR still make small decisions that don't make a lot of sense. They get shit wrong all the time, just in small subtle way. That said, there's likely to be some new bugs introduced with this massive refactor. I'm fairly confident that there's no major architectural flaws that would cause goal seeking agents to make dumb, ugly code decisions. For a little while the legacy llama-swap will be available under cmd/legacy/llama-swap. The plan is to eventually delete that entry point as well as the proxy package. On a bit of a personal note, this PR is exciting and a bit sad for me. I hand wrote much of the original code and this PR ultimately replaces much of it. While the old code served as a good reference for the agent to implement the new stuff it still a bit sad to eventually delete it all.
This commit is contained in:
Vendored
+102
@@ -0,0 +1,102 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
|
||||
ErrNotFound = errors.New("item not found")
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
mu sync.Mutex
|
||||
items map[int][]byte
|
||||
order []int
|
||||
size int
|
||||
maxSize int
|
||||
}
|
||||
|
||||
func New(maxBytes int) *Cache {
|
||||
return &Cache{
|
||||
items: make(map[int][]byte),
|
||||
order: make([]int, 0),
|
||||
maxSize: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Add(id int, data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
dataSize := len(data)
|
||||
if dataSize > c.maxSize {
|
||||
return ErrExceedsMaxSize
|
||||
}
|
||||
|
||||
// If key already exists, remove old entry from size and order
|
||||
if old, exists := c.items[id]; exists {
|
||||
c.size -= len(old)
|
||||
c.removeOrder(id)
|
||||
}
|
||||
|
||||
// Evict oldest (FIFO) until room available
|
||||
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
|
||||
oldestID := c.order[0]
|
||||
c.order = c.order[1:]
|
||||
if evicted, exists := c.items[oldestID]; exists {
|
||||
c.size -= len(evicted)
|
||||
delete(c.items, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
c.items[id] = data
|
||||
c.order = append(c.order, id)
|
||||
c.size += dataSize
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) removeOrder(id int) {
|
||||
for i, v := range c.order {
|
||||
if v == id {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Get(id int) ([]byte, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
data, exists := c.items[id]
|
||||
if !exists {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Cache) Has(id int) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
_, exists := c.items[id]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (c *Cache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.size
|
||||
}
|
||||
|
||||
func (c *Cache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items = make(map[int][]byte)
|
||||
c.order = c.order[:0]
|
||||
c.size = 0
|
||||
}
|
||||
Vendored
+130
@@ -0,0 +1,130 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCache_Add(t *testing.T) {
|
||||
t.Run("adds and retrieves item", func(t *testing.T) {
|
||||
c := New(1024)
|
||||
data := []byte("hello")
|
||||
require.NoError(t, c.Add(1, data))
|
||||
|
||||
got, err := c.Get(1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, data, got)
|
||||
})
|
||||
|
||||
t.Run("returns error for oversized item", func(t *testing.T) {
|
||||
c := New(10)
|
||||
err := c.Add(1, make([]byte, 20))
|
||||
assert.ErrorIs(t, err, ErrExceedsMaxSize)
|
||||
})
|
||||
|
||||
t.Run("evicts oldest items to make room", func(t *testing.T) {
|
||||
c := New(100)
|
||||
|
||||
require.NoError(t, c.Add(1, make([]byte, 40)))
|
||||
require.NoError(t, c.Add(2, make([]byte, 40)))
|
||||
// Adding item 3 should evict item 1
|
||||
require.NoError(t, c.Add(3, make([]byte, 40)))
|
||||
|
||||
assert.False(t, c.Has(1))
|
||||
assert.True(t, c.Has(2))
|
||||
assert.True(t, c.Has(3))
|
||||
})
|
||||
|
||||
t.Run("overwrites existing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
require.NoError(t, c.Add(1, []byte("old")))
|
||||
require.NoError(t, c.Add(1, []byte("new")))
|
||||
|
||||
got, err := c.Get(1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("new"), got)
|
||||
assert.Equal(t, 3, c.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Get(t *testing.T) {
|
||||
t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
_, err := c.Get(99)
|
||||
assert.ErrorIs(t, err, ErrNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Has(t *testing.T) {
|
||||
t.Run("returns true for existing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
require.NoError(t, c.Add(1, []byte("data")))
|
||||
assert.True(t, c.Has(1))
|
||||
})
|
||||
|
||||
t.Run("returns false for missing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
assert.False(t, c.Has(1))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Size(t *testing.T) {
|
||||
t.Run("tracks byte usage", func(t *testing.T) {
|
||||
c := New(1000)
|
||||
assert.Equal(t, 0, c.Size())
|
||||
|
||||
require.NoError(t, c.Add(1, make([]byte, 100)))
|
||||
assert.Equal(t, 100, c.Size())
|
||||
|
||||
require.NoError(t, c.Add(2, make([]byte, 200)))
|
||||
assert.Equal(t, 300, c.Size())
|
||||
})
|
||||
|
||||
t.Run("updates on eviction", func(t *testing.T) {
|
||||
c := New(150)
|
||||
require.NoError(t, c.Add(1, make([]byte, 100)))
|
||||
require.NoError(t, c.Add(2, make([]byte, 100)))
|
||||
|
||||
// Item 1 should be evicted, size = 100
|
||||
assert.Equal(t, 100, c.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Clear(t *testing.T) {
|
||||
t.Run("removes all items and resets size", func(t *testing.T) {
|
||||
c := New(1000)
|
||||
require.NoError(t, c.Add(1, []byte("a")))
|
||||
require.NoError(t, c.Add(2, []byte("b")))
|
||||
|
||||
c.Clear()
|
||||
|
||||
assert.Equal(t, 0, c.Size())
|
||||
assert.False(t, c.Has(1))
|
||||
assert.False(t, c.Has(2))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent operations are safe", func(t *testing.T) {
|
||||
c := New(10000)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 100; j++ {
|
||||
key := id*100 + j
|
||||
_ = c.Add(key, []byte("data"))
|
||||
_, _ = c.Get(key)
|
||||
_ = c.Has(key)
|
||||
_ = c.Size()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
// Package chain composes http.Handler middleware into a single handler.
|
||||
//
|
||||
// A Middleware wraps a downstream http.Handler and may run logic before or
|
||||
// after delegating to it, or short-circuit by not calling next at all
|
||||
// (e.g. auth failure, CORS preflight).
|
||||
package chain
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Middleware wraps an http.Handler with cross-cutting behavior. It receives
|
||||
// the next handler in the chain and returns a handler that may call next,
|
||||
// modify the request/response around it, or short-circuit.
|
||||
type Middleware func(next http.Handler) http.Handler
|
||||
|
||||
// Chain is a reusable middleware stack. Build it once with New (and optionally
|
||||
// extend per-route with Append), then call Then to wrap each terminal handler
|
||||
// when registering routes against an http.ServeMux:
|
||||
//
|
||||
// api := chain.New(authMW, corsMW)
|
||||
// mux.Handle("/v1/chat/completions", api.Then(dispatch))
|
||||
// mux.Handle("/v1/embeddings", api.Append(filters).Then(dispatch))
|
||||
//
|
||||
// Middlewares execute left-to-right: mws[0] runs first and may call into
|
||||
// mws[1], and so on, with the terminal handler invoked last. A middleware
|
||||
// that does not call next short-circuits the remainder of the chain.
|
||||
// A zero Chain is valid and applies no middleware.
|
||||
type Chain struct {
|
||||
mws []Middleware
|
||||
}
|
||||
|
||||
// New returns a Chain that applies mws left-to-right around any terminal
|
||||
// handler passed to Then.
|
||||
func New(mws ...Middleware) Chain {
|
||||
cp := make([]Middleware, len(mws))
|
||||
copy(cp, mws)
|
||||
return Chain{mws: cp}
|
||||
}
|
||||
|
||||
// Append returns a new Chain with mws added after the existing middleware.
|
||||
// The receiver is not modified, so a base Chain can be safely reused across
|
||||
// multiple routes that each need different per-route additions.
|
||||
func (c Chain) Append(mws ...Middleware) Chain {
|
||||
out := make([]Middleware, 0, len(c.mws)+len(mws))
|
||||
out = append(out, c.mws...)
|
||||
out = append(out, mws...)
|
||||
return Chain{mws: out}
|
||||
}
|
||||
|
||||
// Then wraps final with the chain's middleware and returns the resulting
|
||||
// handler, suitable for passing to http.ServeMux.Handle. With an empty chain,
|
||||
// Then returns final unchanged.
|
||||
func (c Chain) Then(final http.Handler) http.Handler {
|
||||
h := final
|
||||
for i := len(c.mws) - 1; i >= 0; i-- {
|
||||
h = c.mws[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// ThenFunc is shorthand for Then(http.HandlerFunc(f)).
|
||||
func (c Chain) ThenFunc(f http.HandlerFunc) http.Handler {
|
||||
return c.Then(f)
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// recordingMiddleware appends tag before calling next and "-after-"+tag after.
|
||||
func recordingMiddleware(tag string, log *[]string) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
*log = append(*log, tag)
|
||||
next.ServeHTTP(w, r)
|
||||
*log = append(*log, "after-"+tag)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_HandlersExecuteInDeclaredOrder(t *testing.T) {
|
||||
var log []string
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "final")
|
||||
})
|
||||
|
||||
h := New(
|
||||
recordingMiddleware("a", &log),
|
||||
recordingMiddleware("b", &log),
|
||||
recordingMiddleware("c", &log),
|
||||
).Then(final)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
want := []string{"a", "b", "c", "final", "after-c", "after-b", "after-a"}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("execution order mismatch:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_ShortCircuitsWhenMiddlewareDoesNotCallNext(t *testing.T) {
|
||||
var log []string
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "final")
|
||||
})
|
||||
|
||||
gate := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "gate")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
h := New(
|
||||
recordingMiddleware("outer", &log),
|
||||
gate,
|
||||
recordingMiddleware("inner", &log),
|
||||
).Then(final)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
want := []string{"outer", "gate", "after-outer"}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("short-circuit order mismatch:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_EarlyWritesAreVisibleToLaterMiddleware(t *testing.T) {
|
||||
header := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Set-By", "outer")
|
||||
_, _ = io.WriteString(w, "outer:")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
inner := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// The outer middleware already set the header; we should see it.
|
||||
if got := w.Header().Get("X-Set-By"); got != "outer" {
|
||||
_, _ = io.WriteString(w, "missing-header;")
|
||||
}
|
||||
_, _ = io.WriteString(w, "inner:")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "final")
|
||||
})
|
||||
|
||||
h := New(header, inner).Then(final)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
body, _ := io.ReadAll(rec.Body)
|
||||
if got := string(body); !strings.Contains(got, "outer:inner:final") {
|
||||
t.Fatalf("body: got %q, want it to contain %q", got, "outer:inner:final")
|
||||
}
|
||||
if got := rec.Header().Get("X-Set-By"); got != "outer" {
|
||||
t.Fatalf("header X-Set-By: got %q, want %q", got, "outer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_ReusableAcrossRoutesViaThen(t *testing.T) {
|
||||
var log []string
|
||||
base := New(
|
||||
recordingMiddleware("auth", &log),
|
||||
recordingMiddleware("cors", &log),
|
||||
)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/a", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "handler-a")
|
||||
}))
|
||||
mux.Handle("/b", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "handler-b")
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
for _, path := range []string{"/a", "/b"} {
|
||||
resp, err := http.Get(srv.URL + path)
|
||||
if err != nil {
|
||||
t.Fatalf("GET %s: %v", path, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
want := []string{
|
||||
"auth", "cors", "handler-a", "after-cors", "after-auth",
|
||||
"auth", "cors", "handler-b", "after-cors", "after-auth",
|
||||
}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("reusable chain order mismatch:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AppendDoesNotMutateReceiver(t *testing.T) {
|
||||
var log []string
|
||||
base := New(recordingMiddleware("base", &log))
|
||||
extended := base.Append(recordingMiddleware("extra", &log))
|
||||
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "final")
|
||||
})
|
||||
|
||||
// Run extended first to surface any aliasing of the underlying slice.
|
||||
rec := httptest.NewRecorder()
|
||||
extended.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
base.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
want := []string{
|
||||
"base", "extra", "final", "after-extra", "after-base",
|
||||
"base", "final", "after-base",
|
||||
}
|
||||
if !equal(log, want) {
|
||||
t.Fatalf("Append must not mutate the receiver:\n got: %v\nwant: %v", log, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_ZeroValueAndEmptyThenAreIdentity(t *testing.T) {
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
})
|
||||
|
||||
for name, c := range map[string]Chain{
|
||||
"zero": {},
|
||||
"empty": New(),
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := c.Then(final)
|
||||
if _, ok := h.(http.HandlerFunc); !ok {
|
||||
t.Fatalf("expected http.HandlerFunc identity, got %T", h)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if rec.Code != http.StatusTeapot {
|
||||
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusTeapot)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func equal(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,828 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type MacroList []MacroEntry
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("macros must be a mapping")
|
||||
}
|
||||
|
||||
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||
}
|
||||
|
||||
var val any
|
||||
if err := valueNode.Decode(&val); err != nil {
|
||||
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||
}
|
||||
|
||||
*ml = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a macro value by name
|
||||
func (ml MacroList) Get(name string) (any, bool) {
|
||||
for _, entry := range ml {
|
||||
if entry.Name == name {
|
||||
return entry.Value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||
func (ml MacroList) ToMap() map[string]any {
|
||||
result := make(map[string]any, len(ml))
|
||||
for _, entry := range ml {
|
||||
result[entry.Name] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
Performance PerformanceConfig `yaml:"performance"`
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// swap matrix: solver-based alternative to groups
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
|
||||
// populated during validation when matrix is configured
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
|
||||
// send loading state in reasoning
|
||||
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
//go:build !windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
# comment 1
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
// Test the default values are automatically set for global, model and group configurations
|
||||
// after loading the configuration
|
||||
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadPosix(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
hooks:
|
||||
on_startup:
|
||||
preload: ["model1", "model2"]
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
name: "Model 1"
|
||||
description: "This is model 1"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
defaultTimeout := TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
OnStartup: HookOnStartup{
|
||||
Preload: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Performance: PerformanceConfig{
|
||||
Every: 5 * time.Second,
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,263 @@
|
||||
//go:build windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// does not support single quoted strings like in config_posix_test.go
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
|
||||
-a "double quotes" \
|
||||
-s
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadWindows(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
defaultTimeout := TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
HealthCheckTimeout: 15,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Performance: PerformanceConfig{
|
||||
Every: 5 * time.Second,
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
|
||||
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
|
||||
// Useful with aliases: a single loaded model can behave differently depending on
|
||||
// which alias the client used. Applied after SetParams, so it can override those values.
|
||||
// Protected params (like "model") cannot be set.
|
||||
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
|
||||
// with protected params removed and keys sorted for consistent iteration order.
|
||||
// Returns nil if the ID has no entry or all its params are protected.
|
||||
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
|
||||
if len(f.SetParamsByID) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
params, found := f.SetParamsByID[requestedModelID]
|
||||
if !found || len(params) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
result := make(map[string]any, len(params))
|
||||
keys := make([]string, 0, len(params))
|
||||
for key, value := range params {
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, keys
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParamsByID map[string]map[string]any
|
||||
requestedModelID string
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty SetParamsByID returns nil",
|
||||
setParamsByID: nil,
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map returns nil",
|
||||
setParamsByID: map[string]map[string]any{},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "non-matching model ID returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model2": {"temperature": 0.9},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "matching model ID returns correct params",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {"temperature": 0.7, "top_p": 0.9},
|
||||
"model2": {"temperature": 0.5},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected param model is filtered out",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "keys are sorted",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
wantKeys: []string{"a_param", "m_param", "z_param"},
|
||||
},
|
||||
{
|
||||
name: "alias style key lookup",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1:high": {"reasoning_effort": "high"},
|
||||
"model1:low": {"reasoning_effort": "low"},
|
||||
},
|
||||
requestedModelID: "model1:high",
|
||||
wantParams: map[string]any{
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
wantKeys: []string{"reasoning_effort"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParamsByID: tt.setParamsByID}
|
||||
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams)
|
||||
assert.Nil(t, gotKeys)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantKeys, gotKeys)
|
||||
assert.Equal(t, tt.wantParams, gotParams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test macro-in-macro basic substitution
|
||||
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${B}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test LIFO substitution order with 3+ macro levels
|
||||
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: load ${full}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test MODEL_ID in global macro used by model
|
||||
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: ${podman-llama} -m model.gguf
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||
}
|
||||
|
||||
// Test model macro overrides global macro in substitution
|
||||
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: echo ${msg}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test self-reference detection error
|
||||
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${recursive}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "recursive")
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test macro substitution in name and description fields
|
||||
func TestConfig_MacroInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"VARIANT": "Q4_K_M"
|
||||
"FAMILY": "llama"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "${FAMILY} ${VARIANT}"
|
||||
description: "A ${FAMILY} model in ${VARIANT} format"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
|
||||
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
|
||||
}
|
||||
|
||||
// Test MODEL_ID macro in name and description fields
|
||||
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
llama-3b:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model: ${MODEL_ID}"
|
||||
description: "Running ${MODEL_ID}"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
|
||||
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
|
||||
}
|
||||
|
||||
// Test unknown macro in name or description returns an error
|
||||
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
test:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model ${UNDEFINED}"
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${A}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
|
||||
|
||||
// MatrixConfig represents the swap matrix configuration block.
|
||||
type MatrixConfig struct {
|
||||
Var map[string]string `yaml:"vars"`
|
||||
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||
Sets OrderedSets `yaml:"sets"`
|
||||
}
|
||||
|
||||
// SetEntry is a single named set with its DSL expression.
|
||||
type SetEntry struct {
|
||||
Name string
|
||||
DSL string
|
||||
}
|
||||
|
||||
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
|
||||
type OrderedSets []SetEntry
|
||||
|
||||
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("sets must be a mapping")
|
||||
}
|
||||
|
||||
entries := make([]SetEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode set name: %w", err)
|
||||
}
|
||||
|
||||
var dsl string
|
||||
if err := valueNode.Decode(&dsl); err != nil {
|
||||
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, SetEntry{Name: name, DSL: dsl})
|
||||
}
|
||||
|
||||
*os = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpandedSet is one valid combination of concurrent models (real model names).
|
||||
type ExpandedSet struct {
|
||||
SetName string
|
||||
DSL string
|
||||
Models []string // real model names, sorted
|
||||
}
|
||||
|
||||
// ValidateMatrix validates the matrix config and returns all expanded sets.
|
||||
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
|
||||
if len(matrix.Sets) == 0 {
|
||||
return nil, fmt.Errorf("matrix must define at least one set")
|
||||
}
|
||||
|
||||
if len(matrix.Var) == 0 {
|
||||
return nil, fmt.Errorf("matrix must define at least one var")
|
||||
}
|
||||
|
||||
// Validate var entries
|
||||
if matrix.Var != nil {
|
||||
for id, modelName := range matrix.Var {
|
||||
if !varKeyPattern.MatchString(id) {
|
||||
return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
|
||||
}
|
||||
if _, exists := models[modelName]; !exists {
|
||||
return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate evict_costs
|
||||
if matrix.EvictCosts != nil {
|
||||
for key, cost := range matrix.EvictCosts {
|
||||
if cost <= 0 {
|
||||
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
|
||||
}
|
||||
if _, ok := matrix.Var[key]; !ok {
|
||||
return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build dependency graph for +ref topological sort
|
||||
setNames := make(map[string]bool)
|
||||
for _, entry := range matrix.Sets {
|
||||
setNames[entry.Name] = true
|
||||
}
|
||||
|
||||
deps := make(map[string][]string) // setName -> set names it depends on
|
||||
for _, entry := range matrix.Sets {
|
||||
refs, err := extractRefs(entry.DSL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
|
||||
}
|
||||
for _, ref := range refs {
|
||||
if !setNames[ref] {
|
||||
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
|
||||
}
|
||||
}
|
||||
deps[entry.Name] = refs
|
||||
}
|
||||
|
||||
// Topological sort with cycle detection
|
||||
order, err := topologicalSort(matrix.Sets, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Expand sets in topological order
|
||||
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
|
||||
var allExpanded []ExpandedSet
|
||||
totalCombinations := 0
|
||||
|
||||
// Build ordered map for efficient lookup
|
||||
setDSL := make(map[string]string)
|
||||
for _, entry := range matrix.Sets {
|
||||
setDSL[entry.Name] = entry.DSL
|
||||
}
|
||||
|
||||
for _, name := range order {
|
||||
dsl := setDSL[name]
|
||||
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set %q: %w", name, err)
|
||||
}
|
||||
|
||||
resolvedRefs[name] = combos
|
||||
|
||||
// Resolve var IDs to real model names
|
||||
for _, combo := range combos {
|
||||
resolved := make([]string, len(combo))
|
||||
for i, ident := range combo {
|
||||
realName, ok := matrix.Var[ident]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
|
||||
}
|
||||
resolved[i] = realName
|
||||
}
|
||||
sort.Strings(resolved)
|
||||
allExpanded = append(allExpanded, ExpandedSet{
|
||||
SetName: name,
|
||||
DSL: dsl,
|
||||
Models: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
totalCombinations += len(combos)
|
||||
if totalCombinations > maxDSLExpansions {
|
||||
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
|
||||
}
|
||||
}
|
||||
|
||||
return allExpanded, nil
|
||||
}
|
||||
|
||||
// topologicalSort returns set names in dependency order.
|
||||
// Returns an error if a cycle is detected.
|
||||
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
|
||||
// States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||
state := make(map[string]int)
|
||||
var order []string
|
||||
|
||||
var visit func(name string) error
|
||||
visit = func(name string) error {
|
||||
switch state[name] {
|
||||
case 1:
|
||||
return fmt.Errorf("circular reference detected involving set %q", name)
|
||||
case 2:
|
||||
return nil
|
||||
}
|
||||
state[name] = 1
|
||||
|
||||
for _, dep := range deps[name] {
|
||||
if err := visit(dep); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
state[name] = 2
|
||||
order = append(order, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Visit in definition order for deterministic output
|
||||
for _, entry := range sets {
|
||||
if state[entry.Name] == 0 {
|
||||
if err := visit(entry.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// ResolvedEvictCosts returns a map of real model name -> evict cost,
|
||||
// resolving var IDs. Models not listed default to 1.
|
||||
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
|
||||
costs := make(map[string]int)
|
||||
if m.EvictCosts == nil {
|
||||
return costs
|
||||
}
|
||||
for key, cost := range m.EvictCosts {
|
||||
// Resolve var ID if present
|
||||
if realName, ok := m.Var[key]; ok {
|
||||
costs[realName] = cost
|
||||
} else {
|
||||
costs[key] = cost
|
||||
}
|
||||
}
|
||||
return costs
|
||||
}
|
||||
@@ -0,0 +1,376 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const maxDSLExpansions = 1000
|
||||
|
||||
// Token types for the DSL lexer
|
||||
type tokenType int
|
||||
|
||||
const (
|
||||
tokIdent tokenType = iota // model alias or name
|
||||
tokAnd // &
|
||||
tokOr // |
|
||||
tokLParen // (
|
||||
tokRParen // )
|
||||
tokRef // +setName
|
||||
tokEOF
|
||||
)
|
||||
|
||||
type token struct {
|
||||
typ tokenType
|
||||
val string
|
||||
}
|
||||
|
||||
// tokenize splits a DSL string into tokens.
|
||||
func tokenize(input string) ([]token, error) {
|
||||
var tokens []token
|
||||
i := 0
|
||||
runes := []rune(input)
|
||||
|
||||
for i < len(runes) {
|
||||
ch := runes[i]
|
||||
|
||||
// skip whitespace
|
||||
if unicode.IsSpace(ch) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '&':
|
||||
tokens = append(tokens, token{tokAnd, "&"})
|
||||
i++
|
||||
case '|':
|
||||
tokens = append(tokens, token{tokOr, "|"})
|
||||
i++
|
||||
case '(':
|
||||
tokens = append(tokens, token{tokLParen, "("})
|
||||
i++
|
||||
case ')':
|
||||
tokens = append(tokens, token{tokRParen, ")"})
|
||||
i++
|
||||
case '+':
|
||||
// +ref: read the identifier that follows
|
||||
i++
|
||||
start := i
|
||||
for i < len(runes) && isIdentChar(runes[i]) {
|
||||
i++
|
||||
}
|
||||
if i == start {
|
||||
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
|
||||
}
|
||||
tokens = append(tokens, token{tokRef, string(runes[start:i])})
|
||||
default:
|
||||
if isIdentChar(ch) {
|
||||
start := i
|
||||
for i < len(runes) && isIdentChar(runes[i]) {
|
||||
i++
|
||||
}
|
||||
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokens = append(tokens, token{tokEOF, ""})
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func isIdentChar(ch rune) bool {
|
||||
return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
|
||||
}
|
||||
|
||||
// AST node types
|
||||
type dslNode interface {
|
||||
dslNode()
|
||||
}
|
||||
|
||||
type andNode struct {
|
||||
children []dslNode
|
||||
}
|
||||
|
||||
type orNode struct {
|
||||
children []dslNode
|
||||
}
|
||||
|
||||
type leafNode struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type refNode struct {
|
||||
setName string
|
||||
}
|
||||
|
||||
func (andNode) dslNode() {}
|
||||
func (orNode) dslNode() {}
|
||||
func (leafNode) dslNode() {}
|
||||
func (refNode) dslNode() {}
|
||||
|
||||
// parser holds state for recursive-descent parsing.
|
||||
type parser struct {
|
||||
tokens []token
|
||||
pos int
|
||||
}
|
||||
|
||||
func (p *parser) peek() token {
|
||||
if p.pos < len(p.tokens) {
|
||||
return p.tokens[p.pos]
|
||||
}
|
||||
return token{tokEOF, ""}
|
||||
}
|
||||
|
||||
func (p *parser) next() token {
|
||||
t := p.peek()
|
||||
if t.typ != tokEOF {
|
||||
p.pos++
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (p *parser) expect(typ tokenType) (token, error) {
|
||||
t := p.next()
|
||||
if t.typ != typ {
|
||||
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Grammar:
|
||||
//
|
||||
// expr = andExpr
|
||||
// andExpr = orExpr ('&' orExpr)*
|
||||
// orExpr = atom ('|' atom)*
|
||||
// atom = ident | '+' ident | '(' expr ')'
|
||||
//
|
||||
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
|
||||
func parse(tokens []token) (dslNode, error) {
|
||||
p := &parser{tokens: tokens}
|
||||
node, err := p.parseExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.peek().typ != tokEOF {
|
||||
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseExpr() (dslNode, error) {
|
||||
return p.parseOrExpr()
|
||||
}
|
||||
|
||||
func (p *parser) parseOrExpr() (dslNode, error) {
|
||||
left, err := p.parseAndExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.peek().typ == tokOr {
|
||||
children := []dslNode{left}
|
||||
for p.peek().typ == tokOr {
|
||||
p.next() // consume |
|
||||
right, err := p.parseAndExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
children = append(children, right)
|
||||
}
|
||||
return orNode{children: children}, nil
|
||||
}
|
||||
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseAndExpr() (dslNode, error) {
|
||||
left, err := p.parseAtom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.peek().typ == tokAnd {
|
||||
children := []dslNode{left}
|
||||
for p.peek().typ == tokAnd {
|
||||
p.next() // consume &
|
||||
right, err := p.parseAtom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
children = append(children, right)
|
||||
}
|
||||
return andNode{children: children}, nil
|
||||
}
|
||||
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseAtom() (dslNode, error) {
|
||||
t := p.peek()
|
||||
|
||||
switch t.typ {
|
||||
case tokIdent:
|
||||
p.next()
|
||||
return leafNode{name: t.val}, nil
|
||||
|
||||
case tokRef:
|
||||
p.next()
|
||||
return refNode{setName: t.val}, nil
|
||||
|
||||
case tokLParen:
|
||||
p.next() // consume (
|
||||
node, err := p.parseExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := p.expect(tokRParen); err != nil {
|
||||
return nil, fmt.Errorf("missing closing parenthesis")
|
||||
}
|
||||
return node, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected token %q", t.val)
|
||||
}
|
||||
}
|
||||
|
||||
// expand walks the AST and produces all combinations.
|
||||
// resolvedRefs contains previously expanded sets for +ref resolution.
|
||||
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||
switch n := node.(type) {
|
||||
case leafNode:
|
||||
return [][]string{{n.name}}, nil
|
||||
|
||||
case refNode:
|
||||
expanded, ok := resolvedRefs[n.setName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
|
||||
}
|
||||
// Return a copy
|
||||
result := make([][]string, len(expanded))
|
||||
for i, combo := range expanded {
|
||||
result[i] = make([]string, len(combo))
|
||||
copy(result[i], combo)
|
||||
}
|
||||
return result, nil
|
||||
|
||||
case orNode:
|
||||
// Union of all children's expansions
|
||||
var result [][]string
|
||||
for _, child := range n.children {
|
||||
childResult, err := expand(child, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, childResult...)
|
||||
if len(result) > maxDSLExpansions {
|
||||
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
|
||||
case andNode:
|
||||
// Cartesian product across children
|
||||
result := [][]string{{}} // start with one empty combo
|
||||
for _, child := range n.children {
|
||||
childResult, err := expand(child, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err = cartesianProduct(result, childResult, maxDSLExpansions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown node type %T", node)
|
||||
}
|
||||
}
|
||||
|
||||
// cartesianProduct computes the cartesian product of two sets of combinations.
|
||||
// It returns an error if the product would exceed cap.
|
||||
func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
|
||||
if int64(len(left))*int64(len(right)) > int64(cap) {
|
||||
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
|
||||
}
|
||||
result := make([][]string, 0, len(left)*len(right))
|
||||
for _, l := range left {
|
||||
for _, r := range right {
|
||||
combo := make([]string, 0, len(l)+len(r))
|
||||
combo = append(combo, l...)
|
||||
combo = append(combo, r...)
|
||||
result = append(result, combo)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
|
||||
// resolvedRefs contains previously expanded sets for +ref inlining.
|
||||
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||
dsl = strings.TrimSpace(dsl)
|
||||
if dsl == "" {
|
||||
return nil, fmt.Errorf("empty DSL expression")
|
||||
}
|
||||
|
||||
tokens, err := tokenize(dsl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tokenize: %w", err)
|
||||
}
|
||||
|
||||
tree, err := parse(tokens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse: %w", err)
|
||||
}
|
||||
|
||||
result, err := expand(tree, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Deduplicate models within each combination and sort for consistency
|
||||
for i, combo := range result {
|
||||
result[i] = dedupAndSort(combo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// dedupAndSort removes duplicate entries and sorts alphabetically.
|
||||
func dedupAndSort(items []string) []string {
|
||||
seen := make(map[string]bool, len(items))
|
||||
var unique []string
|
||||
for _, item := range items {
|
||||
if !seen[item] {
|
||||
seen[item] = true
|
||||
unique = append(unique, item)
|
||||
}
|
||||
}
|
||||
sort.Strings(unique)
|
||||
return unique
|
||||
}
|
||||
|
||||
// extractRefs scans a DSL string for +ref tokens without full parsing.
|
||||
// Used for building the dependency graph for topological sorting.
|
||||
func extractRefs(dsl string) ([]string, error) {
|
||||
tokens, err := tokenize(dsl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var refs []string
|
||||
seen := make(map[string]bool)
|
||||
for _, t := range tokens {
|
||||
if t.typ == tokRef && !seen[t.val] {
|
||||
seen[t.val] = true
|
||||
refs = append(refs, t.val)
|
||||
}
|
||||
}
|
||||
return refs, nil
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDSL_Tokenize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect []token
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "single identifier",
|
||||
input: "abc",
|
||||
expect: []token{
|
||||
{tokIdent, "abc"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "identifier with hyphens and dots",
|
||||
input: "model-name.v2",
|
||||
expect: []token{
|
||||
{tokIdent, "model-name.v2"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "and expression",
|
||||
input: "a & b",
|
||||
expect: []token{
|
||||
{tokIdent, "a"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "b"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "or expression",
|
||||
input: "a | b",
|
||||
expect: []token{
|
||||
{tokIdent, "a"},
|
||||
{tokOr, "|"},
|
||||
{tokIdent, "b"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parentheses",
|
||||
input: "(a | b) & c",
|
||||
expect: []token{
|
||||
{tokLParen, "("},
|
||||
{tokIdent, "a"},
|
||||
{tokOr, "|"},
|
||||
{tokIdent, "b"},
|
||||
{tokRParen, ")"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "c"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ref token",
|
||||
input: "+llms & v",
|
||||
expect: []token{
|
||||
{tokRef, "llms"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "v"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no whitespace",
|
||||
input: "(a|b)&c",
|
||||
expect: []token{
|
||||
{tokLParen, "("},
|
||||
{tokIdent, "a"},
|
||||
{tokOr, "|"},
|
||||
{tokIdent, "b"},
|
||||
{tokRParen, ")"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "c"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty ref",
|
||||
input: "+",
|
||||
errMsg: "expected set name after '+'",
|
||||
},
|
||||
{
|
||||
name: "invalid character",
|
||||
input: "a @ b",
|
||||
errMsg: "unexpected character",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokens, err := tokenize(tt.input)
|
||||
if tt.errMsg != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, tokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSL_ParseAndExpand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dsl string
|
||||
refs map[string][][]string
|
||||
expect [][]string
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "single model",
|
||||
dsl: "L",
|
||||
expect: [][]string{{"L"}},
|
||||
},
|
||||
{
|
||||
name: "two models with AND",
|
||||
dsl: "a & b",
|
||||
expect: [][]string{{"a", "b"}},
|
||||
},
|
||||
{
|
||||
name: "two models with OR",
|
||||
dsl: "a | b",
|
||||
expect: [][]string{{"a"}, {"b"}},
|
||||
},
|
||||
{
|
||||
name: "three models with OR",
|
||||
dsl: "a | b | c",
|
||||
expect: [][]string{{"a"}, {"b"}, {"c"}},
|
||||
},
|
||||
{
|
||||
name: "cartesian product (a|b) & (c|d)",
|
||||
dsl: "(a | b) & (c | d)",
|
||||
expect: [][]string{
|
||||
{"a", "c"},
|
||||
{"a", "d"},
|
||||
{"b", "c"},
|
||||
{"b", "d"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "three-way AND",
|
||||
dsl: "a & b & c",
|
||||
expect: [][]string{
|
||||
{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "(g | q | m) & v",
|
||||
dsl: "(g | q | m) & v",
|
||||
expect: [][]string{
|
||||
{"g", "v"},
|
||||
{"q", "v"},
|
||||
{"m", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "(g | q) & v & e",
|
||||
dsl: "(g | q) & v & e",
|
||||
expect: [][]string{
|
||||
{"e", "g", "v"},
|
||||
{"e", "q", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "precedence: a | b & c means a | (b & c)",
|
||||
dsl: "a | b & c",
|
||||
expect: [][]string{
|
||||
{"a"},
|
||||
{"b", "c"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "+ref inlining",
|
||||
dsl: "+llms & v",
|
||||
refs: map[string][][]string{
|
||||
"llms": {{"g"}, {"q"}, {"m"}},
|
||||
},
|
||||
expect: [][]string{
|
||||
{"g", "v"},
|
||||
{"q", "v"},
|
||||
{"m", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "+ref chained",
|
||||
dsl: "+with_tts & e",
|
||||
refs: map[string][][]string{
|
||||
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
|
||||
},
|
||||
expect: [][]string{
|
||||
{"e", "g", "v"},
|
||||
{"e", "q", "v"},
|
||||
{"e", "m", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dedup within combination",
|
||||
dsl: "a & a",
|
||||
expect: [][]string{
|
||||
{"a"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty expression",
|
||||
dsl: "",
|
||||
errMsg: "empty DSL expression",
|
||||
},
|
||||
{
|
||||
name: "unmatched open paren",
|
||||
dsl: "(a | b",
|
||||
errMsg: "missing closing parenthesis",
|
||||
},
|
||||
{
|
||||
name: "unmatched close paren",
|
||||
dsl: "a | b)",
|
||||
errMsg: "unexpected token",
|
||||
},
|
||||
{
|
||||
name: "unknown ref",
|
||||
dsl: "+unknown",
|
||||
errMsg: "unknown set reference +unknown",
|
||||
},
|
||||
{
|
||||
name: "empty parens",
|
||||
dsl: "()",
|
||||
errMsg: "unexpected token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
refs := tt.refs
|
||||
if refs == nil {
|
||||
refs = map[string][][]string{}
|
||||
}
|
||||
result, err := ParseAndExpandDSL(tt.dsl, refs)
|
||||
if tt.errMsg != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSL_ExpansionCap(t *testing.T) {
|
||||
// Build an expression that would exceed 1000 combinations:
|
||||
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
|
||||
var aItems, bItems []string
|
||||
for i := 0; i < 32; i++ {
|
||||
aItems = append(aItems, fmt.Sprintf("a%d", i))
|
||||
bItems = append(bItems, fmt.Sprintf("b%d", i))
|
||||
}
|
||||
dsl := fmt.Sprintf("(%s) & (%s)",
|
||||
join(aItems, " | "),
|
||||
join(bItems, " | "),
|
||||
)
|
||||
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeded")
|
||||
}
|
||||
|
||||
func TestDSL_ExtractRefs(t *testing.T) {
|
||||
refs, err := extractRefs("+llms & v & +other")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"llms", "other"}, refs)
|
||||
|
||||
refs, err = extractRefs("a & b")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, refs)
|
||||
}
|
||||
|
||||
func join(items []string, sep string) string {
|
||||
result := ""
|
||||
for i, item := range items {
|
||||
if i > 0 {
|
||||
result += sep
|
||||
}
|
||||
result += item
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func makeModels(names ...string) map[string]ModelConfig {
|
||||
m := make(map[string]ModelConfig)
|
||||
for _, name := range names {
|
||||
m[name] = ModelConfig{Cmd: "echo " + name}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestValidateMatrix_Basic(t *testing.T) {
|
||||
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{
|
||||
"g": "gemma",
|
||||
"q": "qwen",
|
||||
"m": "mistral",
|
||||
"v": "voxtral",
|
||||
"L": "llama70B",
|
||||
},
|
||||
EvictCosts: map[string]int{
|
||||
"L": 30,
|
||||
"v": 50,
|
||||
},
|
||||
Sets: OrderedSets{
|
||||
{Name: "standard", DSL: "(g | q | m) & v"},
|
||||
{Name: "full", DSL: "L"},
|
||||
},
|
||||
}
|
||||
|
||||
expanded, err := ValidateMatrix(matrix, models)
|
||||
require.NoError(t, err)
|
||||
|
||||
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||
// full expands to [llama70B]
|
||||
assert.Len(t, expanded, 4)
|
||||
|
||||
assert.Equal(t, "standard", expanded[0].SetName)
|
||||
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
|
||||
|
||||
assert.Equal(t, "standard", expanded[1].SetName)
|
||||
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
|
||||
|
||||
assert.Equal(t, "standard", expanded[2].SetName)
|
||||
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
|
||||
|
||||
assert.Equal(t, "full", expanded[3].SetName)
|
||||
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
|
||||
}
|
||||
|
||||
func TestValidateMatrix_WithRef(t *testing.T) {
|
||||
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{
|
||||
"g": "gemma",
|
||||
"q": "qwen",
|
||||
"m": "mistral",
|
||||
"v": "voxtral",
|
||||
"e": "reranker",
|
||||
},
|
||||
Sets: OrderedSets{
|
||||
{Name: "llms", DSL: "g | q | m"},
|
||||
{Name: "with_tts", DSL: "+llms & v"},
|
||||
{Name: "mega", DSL: "+with_tts & e"},
|
||||
},
|
||||
}
|
||||
|
||||
expanded, err := ValidateMatrix(matrix, models)
|
||||
require.NoError(t, err)
|
||||
|
||||
// llms: [gemma], [qwen], [mistral]
|
||||
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
|
||||
assert.Len(t, expanded, 9)
|
||||
|
||||
// Check mega entries
|
||||
megaEntries := filterBySetName(expanded, "mega")
|
||||
assert.Len(t, megaEntries, 3)
|
||||
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
|
||||
}
|
||||
|
||||
func TestValidateMatrix_MapIDRequired(t *testing.T) {
|
||||
// DSL cannot use real model names directly — must use var IDs
|
||||
models := makeModels("gemma", "voxtral")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "combo", DSL: "g & voxtral"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown var ID")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
alias string
|
||||
errMsg string
|
||||
}{
|
||||
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
|
||||
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
|
||||
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{tt.alias: "gemma"},
|
||||
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"x": "nonexistent"},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "x"}},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown model")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
t.Run("zero cost", func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
EvictCosts: map[string]int{"g": 0},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "positive integer")
|
||||
})
|
||||
|
||||
t.Run("negative cost", func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
EvictCosts: map[string]int{"g": -1},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "positive integer")
|
||||
})
|
||||
|
||||
t.Run("unknown var ID in evict_costs", func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
EvictCosts: map[string]int{"unknown": 5},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown var ID")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateMatrix_CycleDetection(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "a", DSL: "+b"},
|
||||
{Name: "b", DSL: "+a"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "circular reference")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "a", DSL: "+nonexistent"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "references undefined set")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_NoSets(t *testing.T) {
|
||||
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one set")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "s", DSL: "g & nonexistent"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown var ID")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
|
||||
mc := &MatrixConfig{
|
||||
Var: map[string]string{
|
||||
"g": "gemma",
|
||||
"L": "llama70B",
|
||||
},
|
||||
EvictCosts: map[string]int{
|
||||
"L": 30,
|
||||
"g": 5,
|
||||
},
|
||||
}
|
||||
|
||||
costs := mc.ResolvedEvictCosts()
|
||||
assert.Equal(t, 30, costs["llama70B"])
|
||||
assert.Equal(t, 5, costs["gemma"])
|
||||
}
|
||||
|
||||
func TestValidateMatrix_ConfigXOR(t *testing.T) {
|
||||
// groups and matrix both defined
|
||||
yaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: echo model1
|
||||
proxy: http://localhost:8080
|
||||
groups:
|
||||
group1:
|
||||
members:
|
||||
- model1
|
||||
matrix:
|
||||
sets:
|
||||
s: "model1"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot use both")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
|
||||
yaml := `
|
||||
models:
|
||||
gemma:
|
||||
cmd: echo gemma
|
||||
proxy: http://localhost:8080
|
||||
qwen:
|
||||
cmd: echo qwen
|
||||
proxy: http://localhost:8081
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Matrix)
|
||||
assert.Len(t, cfg.ExpandedSets, 2)
|
||||
// Groups should be empty when matrix is used
|
||||
assert.Empty(t, cfg.Groups)
|
||||
}
|
||||
|
||||
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
|
||||
var result []ExpandedSet
|
||||
for _, s := range sets {
|
||||
if s.SetName == name {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
// 0 = no timeout
|
||||
type TimeoutsConfig struct {
|
||||
Connect int `yaml:"connect"`
|
||||
KeepAlive int `yaml:"keepalive"`
|
||||
ResponseHeader int `yaml:"responseHeader"`
|
||||
TLSHandshake int `yaml:"tlsHandshake"`
|
||||
ExpectContinue int `yaml:"expectContinue"`
|
||||
IdleConn int `yaml:"idleConn"`
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
|
||||
// Macros: see #264
|
||||
// Model level macros take precedence over the global macros
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// Metadata: see #264
|
||||
// Arbitrary metadata that can be exposed through the API
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
|
||||
// override global setting
|
||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
|
||||
// Copy of HealthCheckTimeout from global config
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
|
||||
// matches http.DefaultTransport
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal with the old field name for backwards compatibility
|
||||
if defaults.StripParams == "" {
|
||||
var legacy struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||
}
|
||||
defaults.StripParams = legacy.StripParams
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
# macros inserted and list is cleaned of duplicates and empty strings
|
||||
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
# check for strip_params (legacy field name) compatibility
|
||||
legacy:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
for modelId, modelConfig := range config.Models {
|
||||
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
// model has been removed
|
||||
// empty strings have been removed
|
||||
// duplicates have been removed
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelSendLoadingState(t *testing.T) {
|
||||
content := `
|
||||
sendLoadingState: true
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
sendLoadingState: false
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, config.SendLoadingState)
|
||||
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
|
||||
assert.False(t, *config.Models["model1"].SendLoadingState)
|
||||
}
|
||||
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Keys (other than the model's own ID) should be registered as aliases
|
||||
realName, found := cfg.RealModelName("model1:high")
|
||||
assert.True(t, found, "model1:high should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
realName, found = cfg.RealModelName("model1:low")
|
||||
assert.True(t, found, "model1:low should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
// Auto-aliases should also appear in modelConfig.Aliases
|
||||
aliases := cfg.Models["model1"].Aliases
|
||||
assert.Contains(t, aliases, "model1:high")
|
||||
assert.Contains(t, aliases, "model1:low")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
model2:
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "conflicts with an existing model ID")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: low
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "duplicate alias")
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
|
||||
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
|
||||
// to match the pre PR #619 functionality
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PerformanceConfig holds configuration for system performance monitoring
|
||||
type PerformanceConfig struct {
|
||||
Disabled bool `yaml:"disabled"`
|
||||
Every time.Duration `yaml:"every"`
|
||||
}
|
||||
|
||||
func (p *PerformanceConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPerformanceConfig PerformanceConfig
|
||||
defaults := rawPerformanceConfig{
|
||||
Every: 5 * time.Second,
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*p = PerformanceConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks the PerformanceConfig values and returns an error if invalid
|
||||
func (p *PerformanceConfig) Validate() error {
|
||||
if p.Every < 5*time.Second {
|
||||
return fmt.Errorf("every must be at least 5s, got %v", p.Every)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPerformanceConfig_Defaults(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// When performance section is missing, defaults should be applied
|
||||
assert.False(t, config.Performance.Disabled)
|
||||
assert.Equal(t, 5*time.Second, config.Performance.Every)
|
||||
}
|
||||
|
||||
func TestPerformanceConfig_CustomValues(t *testing.T) {
|
||||
content := `
|
||||
performance:
|
||||
enable: true
|
||||
every: 30s
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, config.Performance.Disabled)
|
||||
assert.Equal(t, 30*time.Second, config.Performance.Every)
|
||||
}
|
||||
|
||||
func TestPerformanceConfig_Disabled(t *testing.T) {
|
||||
content := `
|
||||
performance:
|
||||
disabled: true
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, config.Performance.Disabled)
|
||||
// Duration defaults should still apply
|
||||
assert.Equal(t, 5*time.Second, config.Performance.Every)
|
||||
}
|
||||
|
||||
func TestPerformanceConfig_PartialValues(t *testing.T) {
|
||||
content := `
|
||||
performance:
|
||||
every: 10s
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// enable should default to true
|
||||
assert.False(t, config.Performance.Disabled)
|
||||
assert.Equal(t, 10*time.Second, config.Performance.Every)
|
||||
}
|
||||
|
||||
func TestPerformanceConfig_InvalidEvery(t *testing.T) {
|
||||
content := `
|
||||
performance:
|
||||
every: 4s
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "every must be at least 5s")
|
||||
}
|
||||
|
||||
func TestPerformanceConfig_ComplexDurations(t *testing.T) {
|
||||
content := `
|
||||
performance:
|
||||
every: 1m30s
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 90*time.Second, config.Performance.Every)
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||
|
||||
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Default initializes a default in-process dispatcher
|
||||
var Default = NewDispatcherConfig(25000)
|
||||
|
||||
// On subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work. This
|
||||
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||
func On[T Event](handler func(T)) context.CancelFunc {
|
||||
return Subscribe(Default, handler)
|
||||
}
|
||||
|
||||
// OnType subscribes to an event with the specified event type. This functions
|
||||
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||
return SubscribeTo(Default, eventType, handler)
|
||||
}
|
||||
|
||||
// Emit writes an event into the dispatcher. This functions same way as
|
||||
// Publish() but uses the default dispatcher instead.
|
||||
func Emit[T Event](ev T) {
|
||||
Publish(Default, ev)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||
*/
|
||||
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||
d := NewDispatcher()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||
unsub()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultPublish(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe
|
||||
var count int64
|
||||
defer On(func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(4)
|
||||
Emit(MyEvent1{})
|
||||
Emit(MyEvent1{})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(4), count)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Event represents an event contract
|
||||
type Event interface {
|
||||
Type() uint32
|
||||
}
|
||||
|
||||
// registry holds an immutable sorted array of event mappings
|
||||
type registry struct {
|
||||
keys []uint32 // Event types (sorted)
|
||||
grps []any // Corresponding subscribers
|
||||
}
|
||||
|
||||
// ------------------------------------- Dispatcher -------------------------------------
|
||||
|
||||
// Dispatcher represents an event dispatcher.
|
||||
type Dispatcher struct {
|
||||
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||
done chan struct{} // Cancellation
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||
}
|
||||
|
||||
// NewDispatcher creates a new dispatcher of events.
|
||||
func NewDispatcher() *Dispatcher {
|
||||
return NewDispatcherConfig(50000)
|
||||
}
|
||||
|
||||
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||
d := &Dispatcher{
|
||||
done: make(chan struct{}),
|
||||
maxQueue: maxQueue,
|
||||
}
|
||||
|
||||
d.subs.Store(®istry{
|
||||
keys: make([]uint32, 0, 16),
|
||||
grps: make([]any, 0, 16),
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
// Close closes the dispatcher
|
||||
func (d *Dispatcher) Close() error {
|
||||
close(d.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isClosed returns whether the dispatcher is closed or not
|
||||
func (d *Dispatcher) isClosed() bool {
|
||||
select {
|
||||
case <-d.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// findGroup performs a lock-free binary search for the event type
|
||||
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||
reg := d.subs.Load()
|
||||
keys := reg.keys
|
||||
|
||||
// Inlined binary search for better cache locality
|
||||
left, right := 0, len(keys)
|
||||
for left < right {
|
||||
mid := left + (right-left)/2
|
||||
if keys[mid] < eventType {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
if left < len(keys) && keys[left] == eventType {
|
||||
return reg.grps[left]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work.
|
||||
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||
var event T
|
||||
return SubscribeTo(broker, event.Type(), handler)
|
||||
}
|
||||
|
||||
// SubscribeTo subscribes to an event with the specified event type.
|
||||
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||
if broker.isClosed() {
|
||||
panic(errClosed)
|
||||
}
|
||||
|
||||
broker.mu.Lock()
|
||||
defer broker.mu.Unlock()
|
||||
|
||||
// Check if group already exists
|
||||
if existing := broker.findGroup(eventType); existing != nil {
|
||||
grp := groupOf[T](eventType, existing)
|
||||
sub := grp.Add(handler)
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||
sub := grp.Add(handler)
|
||||
|
||||
// Copy-on-write: insert new entry in sorted position
|
||||
old := broker.subs.Load()
|
||||
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||
return old.keys[i] >= eventType
|
||||
})
|
||||
|
||||
// Create new arrays with space for one more element
|
||||
newKeys := make([]uint32, len(old.keys)+1)
|
||||
newGrps := make([]any, len(old.grps)+1)
|
||||
|
||||
// Copy elements before insertion point
|
||||
copy(newKeys[:idx], old.keys[:idx])
|
||||
copy(newGrps[:idx], old.grps[:idx])
|
||||
|
||||
// Insert new element
|
||||
newKeys[idx] = eventType
|
||||
newGrps[idx] = grp
|
||||
|
||||
// Copy elements after insertion point
|
||||
copy(newKeys[idx+1:], old.keys[idx:])
|
||||
copy(newGrps[idx+1:], old.grps[idx:])
|
||||
|
||||
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||
broker.subs.Store(newReg)
|
||||
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish writes an event into the dispatcher
|
||||
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||
eventType := ev.Type()
|
||||
if sub := broker.findGroup(eventType); sub != nil {
|
||||
group := groupOf[T](eventType, sub)
|
||||
group.Broadcast(ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Count counts the number of subscribers, this is for testing only.
|
||||
func (d *Dispatcher) count(eventType uint32) int {
|
||||
if group := d.findGroup(eventType); group != nil {
|
||||
return group.(interface{ Count() int }).Count()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupOf casts the subscriber group to the specified generic type
|
||||
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||
if group, ok := subs.(*group[T]); ok {
|
||||
return group
|
||||
}
|
||||
|
||||
panic(errConflict[T](eventType, subs))
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber -------------------------------------
|
||||
|
||||
// consumer represents a consumer with a message queue
|
||||
type consumer[T Event] struct {
|
||||
queue []T // Current work queue
|
||||
stop bool // Stop signal
|
||||
}
|
||||
|
||||
// Listen listens to the event queue and processes events
|
||||
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||
pending := make([]T, 0, 128)
|
||||
|
||||
for {
|
||||
c.L.Lock()
|
||||
for len(s.queue) == 0 {
|
||||
switch {
|
||||
case s.stop:
|
||||
c.L.Unlock()
|
||||
return
|
||||
default:
|
||||
c.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Swap buffers and reset the current queue
|
||||
temp := s.queue
|
||||
s.queue = pending[:0]
|
||||
pending = temp
|
||||
c.L.Unlock()
|
||||
|
||||
// Outside of the critical section, process the work
|
||||
for _, event := range pending {
|
||||
fn(event)
|
||||
}
|
||||
|
||||
// Notify potential publishers waiting due to backpressure
|
||||
c.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber Group -------------------------------------
|
||||
|
||||
// group represents a consumer group
|
||||
type group[T Event] struct {
|
||||
cond *sync.Cond
|
||||
subs []*consumer[T]
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
maxLen int // Current maximum queue length across all consumers
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all consumers
|
||||
func (s *group[T]) Broadcast(ev T) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Calculate current maximum queue length
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
|
||||
// Backpressure: wait if queues are full
|
||||
for s.maxLen >= s.maxQueue {
|
||||
s.cond.Wait()
|
||||
|
||||
// Recalculate after wakeup
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add event to all queues and track new maximum
|
||||
newMax := 0
|
||||
for _, sub := range s.subs {
|
||||
sub.queue = append(sub.queue, ev)
|
||||
if len(sub.queue) > newMax {
|
||||
newMax = len(sub.queue)
|
||||
}
|
||||
}
|
||||
s.maxLen = newMax
|
||||
s.cond.Broadcast() // Wake consumers
|
||||
}
|
||||
|
||||
// Add adds a subscriber to the list
|
||||
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||
sub := &consumer[T]{
|
||||
queue: make([]T, 0, 64),
|
||||
}
|
||||
|
||||
// Add the consumer to the list of active consumers
|
||||
s.cond.L.Lock()
|
||||
s.subs = append(s.subs, sub)
|
||||
s.cond.L.Unlock()
|
||||
|
||||
// Start listening
|
||||
go sub.Listen(s.cond, handler)
|
||||
return sub
|
||||
}
|
||||
|
||||
// Del removes a subscriber from the list
|
||||
func (s *group[T]) Del(sub *consumer[T]) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Search and remove the subscriber
|
||||
sub.stop = true
|
||||
for i, v := range s.subs {
|
||||
if v == sub {
|
||||
copy(s.subs[i:], s.subs[i+1:])
|
||||
s.subs = s.subs[:len(s.subs)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Debugging -------------------------------------
|
||||
|
||||
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||
|
||||
// Count returns the number of subscribers in this group
|
||||
func (s *group[T]) Count() int {
|
||||
return len(s.subs)
|
||||
}
|
||||
|
||||
// String returns string representation of the type
|
||||
func (s *group[T]) String() string {
|
||||
typ := reflect.TypeOf(s).String()
|
||||
idx := strings.LastIndex(typ, "/")
|
||||
typ = typ[idx+1 : len(typ)-1]
|
||||
return typ
|
||||
}
|
||||
|
||||
// errConflict returns a conflict message
|
||||
func errConflict[T any](eventType uint32, existing any) string {
|
||||
var want T
|
||||
return fmt.Sprintf(
|
||||
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||
want, existing, eventType,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe, must be received in order
|
||||
var count int64
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(3)
|
||||
Publish(d, MyEvent1{Number: 1})
|
||||
Publish(d, MyEvent1{Number: 2})
|
||||
Publish(d, MyEvent1{Number: 3})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||
// Nothing
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||
unsubscribe()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
}
|
||||
|
||||
func TestConcurrent(t *testing.T) {
|
||||
const max = 1000000
|
||||
var count int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
d := NewDispatcher()
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
if current := atomic.AddInt64(&count, 1); current == max {
|
||||
wg.Done()
|
||||
}
|
||||
})()
|
||||
|
||||
// Asynchronously publish
|
||||
go func() {
|
||||
for i := 0; i < max; i++ {
|
||||
Publish(d, MyEvent1{})
|
||||
}
|
||||
}()
|
||||
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
// Subscriber that does nothing
|
||||
})()
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, max, int(count))
|
||||
}
|
||||
|
||||
func TestSubscribeDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublishDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
Publish(d, MyEvent1{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseDispatcher(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||
|
||||
assert.NoError(t, d.Close())
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix(t *testing.T) {
|
||||
const amount = 1000
|
||||
for _, subs := range []int{1, 10, 100} {
|
||||
for _, topics := range []int{1, 10} {
|
||||
expected := subs * topics * amount
|
||||
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(expected)
|
||||
|
||||
d := NewDispatcher()
|
||||
for i := 0; i < subs; i++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||
count.Add(1)
|
||||
wg.Done()
|
||||
})()
|
||||
}
|
||||
}
|
||||
|
||||
for n := 0; n < amount; n++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
go Publish(d, MyEvent3{ID: id})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, expected, int(count.Load()))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||
// This test specifically targets the race condition that occurs when multiple
|
||||
// goroutines try to subscribe to different event types simultaneously.
|
||||
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||
|
||||
const numGoroutines = 100
|
||||
const numEventTypes = 50
|
||||
|
||||
d := NewDispatcher()
|
||||
defer d.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var receivedCount int64
|
||||
var subscribedTypes sync.Map // Thread-safe map
|
||||
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Start multiple goroutines that subscribe to different event types concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine subscribes to a unique event type
|
||||
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||
|
||||
// Subscribe to the event type
|
||||
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||
atomic.AddInt64(&receivedCount, 1)
|
||||
})
|
||||
|
||||
// Record that this type was subscribed
|
||||
subscribedTypes.Store(eventType, true)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all subscriptions to complete
|
||||
wg.Wait()
|
||||
|
||||
// Count the number of unique event types subscribed
|
||||
expectedTypes := 0
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
expectedTypes++
|
||||
return true
|
||||
})
|
||||
|
||||
// Small delay to ensure all subscriptions are fully processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Publish events to each subscribed type
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
eventType := key.(uint32)
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
return true
|
||||
})
|
||||
|
||||
// Wait for all events to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify that we received at least the expected number of events
|
||||
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||
received := atomic.LoadInt64(&receivedCount)
|
||||
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||
|
||||
// Verify that we have the expected number of unique event types
|
||||
assert.Equal(t, numEventTypes, expectedTypes,
|
||||
"Should have exactly %d unique event types", numEventTypes)
|
||||
}
|
||||
|
||||
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
|
||||
// Test concurrent subscriptions to the same event type
|
||||
t.Run("SameEventType", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var handlerCount int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||
atomic.AddInt64(&handlerCount, 1)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all handlers were registered by publishing an event
|
||||
atomic.StoreInt64(&handlerCount, 0)
|
||||
Publish(d, MyEvent1{})
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||
"Not all handlers were registered due to race condition")
|
||||
})
|
||||
|
||||
// Test concurrent subscriptions to different event types
|
||||
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
receivedEvents := make(map[uint32]*int64)
|
||||
|
||||
// Create multiple event types and subscribe concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
eventType := uint32(100 + i)
|
||||
counter := new(int64)
|
||||
receivedEvents[eventType] = counter
|
||||
|
||||
wg.Add(1)
|
||||
go func(et uint32, cnt *int64) {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||
atomic.AddInt64(cnt, 1)
|
||||
})
|
||||
}(eventType, counter)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Publish events to all types
|
||||
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
}
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify all event types received their events
|
||||
for eventType, counter := range receivedEvents {
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||
"Event type %d did not receive its event", eventType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackpressure(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
d.maxQueue = 10
|
||||
|
||||
var processedCount int64
|
||||
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||
atomic.AddInt64(&processedCount, 1)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
const eventsToPublish = 1000
|
||||
for i := 0; i < eventsToPublish; i++ {
|
||||
Publish(d, MyEvent3{ID: 0x200})
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all events were eventually processed
|
||||
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||
}
|
||||
|
||||
// ------------------------------------- Test Events -------------------------------------
|
||||
|
||||
const (
|
||||
TypeEvent1 = 0x1
|
||||
TypeEvent2 = 0x2
|
||||
)
|
||||
|
||||
type MyEvent1 struct {
|
||||
Number int
|
||||
}
|
||||
|
||||
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||
|
||||
type MyEvent2 struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||
|
||||
type MyEvent3 struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
)
|
||||
|
||||
const DataEventID = 0x04
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var simpleResponderPath string
|
||||
|
||||
func skipIfNoSimpleResponder(t *testing.T) {
|
||||
t.Helper()
|
||||
if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) {
|
||||
t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
if goos == "windows" {
|
||||
simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("getFreePort: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func simpleResponderCmd(t *testing.T, args ...string) (string, int) {
|
||||
port := getFreePort(t)
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
base := []string{cmdPath, fmt.Sprintf("-port %d", port)}
|
||||
base = append(base, args...)
|
||||
return strings.Join(base, " "), port
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type Process interface {
|
||||
// Run starts the process blocks until the process is terminated.
|
||||
// The timeout parameter controls how long to wait for the process to get
|
||||
// to a ready state to process traffic
|
||||
Run(timeout time.Duration) error
|
||||
|
||||
// WaitReady blocks until the process is ready to serve requests
|
||||
// or the context is cancelled. It returns nil when the process is ready
|
||||
WaitReady(context.Context) error
|
||||
|
||||
// Stop blocks until the process has terminated. It returns nil when
|
||||
// the process terminated as expected (exit 0)
|
||||
Stop(timeout time.Duration) error
|
||||
|
||||
// State returns the current state of the process
|
||||
// Note: this is a snapshot of the state at the time of the call
|
||||
// and may change at any time after the call returns.
|
||||
State() ProcessState
|
||||
|
||||
// ServeHTTP forwards requests to the underlying process
|
||||
// Calling it when the process is not ready will result in a
|
||||
// 503 response with a body indicating it is a llama-swap-error
|
||||
ServeHTTP(http.ResponseWriter, *http.Request)
|
||||
|
||||
// Logger returns the monitor that captures this process's stdout/stderr.
|
||||
Logger() *logmon.Monitor
|
||||
}
|
||||
@@ -0,0 +1,568 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var ErrStartAborted = fmt.Errorf("aborted")
|
||||
|
||||
type runReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type stopReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type waitReadyReq struct {
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type startResult struct {
|
||||
cmd *exec.Cmd
|
||||
cmdDone chan struct{}
|
||||
handlerFn http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type ProcessCommand struct {
|
||||
id string
|
||||
config config.ModelConfig
|
||||
parentCtx context.Context
|
||||
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
runCh chan runReq
|
||||
stopCh chan stopReq
|
||||
waitReadyCh chan waitReadyReq
|
||||
|
||||
// current ProcessState. Written only by run(); read by State() via atomic load.
|
||||
state atomic.Value
|
||||
|
||||
// stores the active reverse-proxy handler when the process is running.
|
||||
// Written only by run(); read by ServeHTTP via atomic load.
|
||||
handler atomic.Pointer[http.HandlerFunc]
|
||||
|
||||
lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion
|
||||
inflight atomic.Int64 // current in-flight ServeHTTP calls
|
||||
}
|
||||
|
||||
var _ Process = (*ProcessCommand)(nil)
|
||||
|
||||
func New(
|
||||
parentCtx context.Context,
|
||||
id string,
|
||||
conf config.ModelConfig,
|
||||
processLogger *logmon.Monitor,
|
||||
proxyLogger *logmon.Monitor,
|
||||
) (*ProcessCommand, error) {
|
||||
p := &ProcessCommand{
|
||||
id: id,
|
||||
config: conf,
|
||||
parentCtx: parentCtx,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
|
||||
runCh: make(chan runReq),
|
||||
stopCh: make(chan stopReq),
|
||||
waitReadyCh: make(chan waitReadyReq),
|
||||
}
|
||||
p.state.Store(StateStopped)
|
||||
|
||||
go p.run()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Logger() *logmon.Monitor { return p.processLogger }
|
||||
|
||||
// run is the single-writer goroutine that owns all mutable lifecycle state
|
||||
// (current ProcessState, the running *exec.Cmd, the active reverse-proxy
|
||||
// handler, and the list of WaitReady subscribers). Every public method
|
||||
// (Run / Stop / State / WaitReady) is a thin client that sends a request on
|
||||
// one of the channels below and waits for a response — this funnels concurrent
|
||||
// callers through a single serialization point so the state machine never
|
||||
// observes a race.
|
||||
func (p *ProcessCommand) run() {
|
||||
// Mutable state — only read/written from this goroutine. ServeHTTP reads
|
||||
// p.handler concurrently, which is why handler is an atomic.Pointer.
|
||||
// p.state mirrors `state` so State() can observe transitions; setState
|
||||
// writes both.
|
||||
state := StateStopped
|
||||
setState := func(s ProcessState) {
|
||||
old := state
|
||||
state = s
|
||||
p.state.Store(s)
|
||||
if old != s {
|
||||
event.Emit(shared.ProcessStateChangeEvent{
|
||||
ProcessName: p.id,
|
||||
OldState: string(old),
|
||||
NewState: string(s),
|
||||
})
|
||||
}
|
||||
}
|
||||
var (
|
||||
cmd *exec.Cmd
|
||||
cmdDone <-chan struct{}
|
||||
readyWaiters []waitReadyReq
|
||||
// runResp parks the in-flight Run caller's response channel. The
|
||||
// interface contract is that Run blocks until the process is
|
||||
// terminated, so we hold this until Stop, parentCtx, or an
|
||||
// upstream exit unblocks it via respondRun.
|
||||
runResp chan<- error
|
||||
)
|
||||
|
||||
// notifyWaiters wakes every blocked WaitReady caller with the given result.
|
||||
// Used on transitions out of StateStarting (ready, failed, aborted, or
|
||||
// shutdown) — anything that resolves the "is it ready yet?" question.
|
||||
notifyWaiters := func(err error) {
|
||||
for _, w := range readyWaiters {
|
||||
select {
|
||||
case w.respond <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
readyWaiters = nil
|
||||
}
|
||||
|
||||
// respondRun delivers the final Run result, if a Run caller is parked.
|
||||
respondRun := func(err error) {
|
||||
if runResp != nil {
|
||||
runResp <- err
|
||||
runResp = nil
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
// Shutdown: parent context cancelled. Tear down any running process,
|
||||
// wake any pending WaitReady callers with an error, then exit the
|
||||
// goroutine permanently. Subsequent public-method calls will fail
|
||||
// because parentCtx.Done() unblocks their send-side selects.
|
||||
case <-p.parentCtx.Done():
|
||||
// Mark shutdown before killProcess so concurrent State() readers
|
||||
// stop treating this process as ready while the (possibly slow)
|
||||
// teardown is in progress.
|
||||
setState(StateShutdown)
|
||||
if cmd != nil {
|
||||
p.handler.Store(nil)
|
||||
p.killProcess(cmd, cmdDone, 100*time.Millisecond)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
return
|
||||
|
||||
// Upstream exited on its own (not via Stop). Drop handler state,
|
||||
// transition to Stopped, and unblock the parked Run caller.
|
||||
// cmdDone is nil while no process is running, so this case is
|
||||
// dormant outside of StateReady.
|
||||
case <-cmdDone:
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
p.handler.Store(nil)
|
||||
setState(StateStopped)
|
||||
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
|
||||
|
||||
// WaitReady: if we're already in a terminal-for-this-question state,
|
||||
// respond immediately; otherwise queue the caller and let a future
|
||||
// state transition wake them via notifyWaiters.
|
||||
case req := <-p.waitReadyCh:
|
||||
switch state {
|
||||
case StateReady:
|
||||
req.respond <- nil
|
||||
case StateShutdown:
|
||||
req.respond <- fmt.Errorf("[%s] shutdown", p.id)
|
||||
default:
|
||||
readyWaiters = append(readyWaiters, req)
|
||||
}
|
||||
|
||||
// Run: start the upstream process. Only valid from StateStopped.
|
||||
// doStart can take a long time (health-check polling), so it runs in
|
||||
// a separate goroutine and we wait on resultCh. While waiting we also
|
||||
// listen for an incoming Stop — that's how callers cancel an in-flight
|
||||
// start.
|
||||
case req := <-p.runCh:
|
||||
if state != StateStopped {
|
||||
req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state)
|
||||
continue
|
||||
}
|
||||
setState(StateStarting)
|
||||
|
||||
startCtx, cancelStart := context.WithCancel(context.Background())
|
||||
resultCh := make(chan startResult, 1)
|
||||
go func() {
|
||||
resultCh <- p.doStart(startCtx, req.timeout)
|
||||
}()
|
||||
|
||||
// pendingStop holds a Stop request that arrived mid-start, so we
|
||||
// can respond to it AFTER we've finished tearing the start down.
|
||||
var pendingStop *stopReq
|
||||
select {
|
||||
// doStart finished on its own — either successfully (latch
|
||||
// cmd/handler and move to Ready) or with an error (back to
|
||||
// Stopped). Either way wake WaitReady subscribers and reply
|
||||
// to the Run caller.
|
||||
case res := <-resultCh:
|
||||
if res.err == nil {
|
||||
cmd = res.cmd
|
||||
cmdDone = res.cmdDone
|
||||
fn := res.handlerFn
|
||||
p.handler.Store(&fn)
|
||||
setState(StateReady)
|
||||
notifyWaiters(nil)
|
||||
// Park the Run response — Run blocks until the process
|
||||
// terminates, so we only fire this when Stop, parentCtx,
|
||||
// or the upstream exit takes the process down.
|
||||
runResp = req.respond
|
||||
|
||||
// Start TTL goroutine if configured — self-terminates
|
||||
// when state leaves StateReady.
|
||||
if p.config.UnloadAfter > 0 {
|
||||
ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if p.State() != StateReady {
|
||||
return
|
||||
}
|
||||
if p.inflight.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter)
|
||||
p.Stop(10 * time.Second)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
setState(StateStopped)
|
||||
notifyWaiters(res.err)
|
||||
req.respond <- res.err
|
||||
}
|
||||
|
||||
// Stop arrived while doStart was still running. Cancel the
|
||||
// start context to abort it, then wait for doStart to return.
|
||||
// If doStart had already crossed the finish line before
|
||||
// cancellation took effect, it returns a live cmd that we
|
||||
// must kill ourselves. The Run caller gets ErrAbort; the Stop
|
||||
// caller is parked in pendingStop and answered below.
|
||||
case stop := <-p.stopCh:
|
||||
cancelStart()
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, stop.timeout)
|
||||
}
|
||||
setState(StateStopped)
|
||||
notifyWaiters(ErrStartAborted)
|
||||
req.respond <- ErrStartAborted
|
||||
pendingStop = &stop
|
||||
|
||||
// Parent context cancelled (e.g. config reload) while doStart
|
||||
// was still running. Stop() returns early when parentCtx is
|
||||
// done and never sends on stopCh, so we must handle shutdown
|
||||
// here to avoid leaving doStart running indefinitely.
|
||||
case <-p.parentCtx.Done():
|
||||
cancelStart()
|
||||
// Mark shutdown before tearing the process down: killProcess
|
||||
// may block (e.g. taskkill on Windows is slow to spawn), and
|
||||
// callers observing State() should see StateShutdown promptly
|
||||
// rather than a stale StateStarting.
|
||||
setState(StateShutdown)
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, 100*time.Millisecond)
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
return
|
||||
}
|
||||
// cancelStart is idempotent; calling it again here ensures the
|
||||
// context is released even on the success path (govet leak check).
|
||||
cancelStart()
|
||||
if pendingStop != nil {
|
||||
pendingStop.respond <- nil
|
||||
}
|
||||
|
||||
// Stop: tear down a running process.
|
||||
case stop := <-p.stopCh:
|
||||
if cmd != nil {
|
||||
setState(StateStopping)
|
||||
p.killProcess(cmd, cmdDone, stop.timeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
p.handler.Store(nil)
|
||||
}
|
||||
// Stop is a no-op (and not an error) when already Stopped — this
|
||||
// is what makes it idempotent for callers that don't track state.
|
||||
setState(StateStopped)
|
||||
respondRun(nil)
|
||||
stop.respond <- nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult {
|
||||
if p.config.Proxy == "" {
|
||||
return startResult{err: fmt.Errorf("upstream proxy missing")}
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)}
|
||||
}
|
||||
|
||||
proxyURL, err := url.Parse(p.config.Proxy)
|
||||
if err != nil {
|
||||
return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)}
|
||||
}
|
||||
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
reverseProxy.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream
|
||||
// disconnects after response headers have been sent. Recover here so the
|
||||
// streaming termination is treated as a normal client/upstream disconnect.
|
||||
// see: https://github.com/golang/go/issues/23643
|
||||
handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if rec == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec)
|
||||
}
|
||||
}
|
||||
}()
|
||||
reverseProxy.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stderr = p.processLogger
|
||||
cmd.Stdout = p.processLogger
|
||||
cmd.Env = append(cmd.Environ(), p.config.Env...)
|
||||
setProcAttributes(cmd)
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
if err := cmd.Start(); err != nil {
|
||||
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
|
||||
}
|
||||
|
||||
go func() {
|
||||
waitErr := cmd.Wait()
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else if waitErr != nil {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
|
||||
}
|
||||
close(cmdDone)
|
||||
}()
|
||||
|
||||
if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
if checkEndpoint == "none" {
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
// Wait 250ms for the command to start up before health checking
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(healthCheckTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
default:
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: fmt.Errorf("health check timed out after %v", healthCheckTimeout)}
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
reverseProxy.ServeHTTP(rr, req)
|
||||
resp := rr.Result()
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
|
||||
break
|
||||
} else if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
stopArgs, err := config.SanitizeCommand(
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", cmd.Process.Pid)),
|
||||
)
|
||||
if err == nil {
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Env = cmd.Env
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Run()
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(gracefulTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
case <-timer.C:
|
||||
cmd.Process.Kill()
|
||||
<-cmdDone
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ID() string {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Run(timeout time.Duration) error {
|
||||
req := runReq{
|
||||
timeout: timeout,
|
||||
respond: make(chan error, 1),
|
||||
}
|
||||
select {
|
||||
case p.runCh <- req:
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
select {
|
||||
case err := <-req.respond:
|
||||
return err
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) WaitReady(ctx context.Context) error {
|
||||
req := waitReadyReq{respond: make(chan error, 1)}
|
||||
select {
|
||||
case p.waitReadyCh <- req:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
select {
|
||||
case err := <-req.respond:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Stop(timeout time.Duration) error {
|
||||
req := stopReq{
|
||||
timeout: timeout,
|
||||
respond: make(chan error, 1),
|
||||
}
|
||||
select {
|
||||
case p.stopCh <- req:
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
return <-req.respond
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) State() ProcessState {
|
||||
if s, ok := p.state.Load().(ProcessState); ok {
|
||||
return s
|
||||
}
|
||||
return StateStopped
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fn := p.handler.Load()
|
||||
if fn == nil {
|
||||
http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
p.inflight.Add(1)
|
||||
defer func() {
|
||||
p.lastUse.Store(time.Now().UnixNano())
|
||||
p.inflight.Add(-1)
|
||||
}()
|
||||
(*fn)(w, r)
|
||||
}
|
||||
@@ -0,0 +1,646 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
const (
|
||||
testStartTimeout = 3 * time.Second
|
||||
testStopTimeout = 2 * time.Second
|
||||
testReturnTimeout = 1 * time.Second
|
||||
testPollInterval = 20 * time.Millisecond
|
||||
testLogPollInterval = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
p, err := New(context.Background(), t.Name(), conf, logger, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// runAsync starts Run in a goroutine and waits until the process is ready,
|
||||
// matching the new interface contract where Run blocks until the process is
|
||||
// terminated. Returns a channel that delivers Run's eventual error.
|
||||
func runAsync(t *testing.T, p *ProcessCommand) <-chan error {
|
||||
t.Helper()
|
||||
ch := make(chan error, 1)
|
||||
go func() { ch <- p.Run(testStartTimeout) }()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout)
|
||||
defer cancel()
|
||||
if err := p.WaitReady(ctx); err != nil {
|
||||
t.Fatalf("WaitReady: %v", err)
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func TestProcessCommand_StartStop(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// before start: no handler
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("before start: expected 503, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||
t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||
}
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Errorf("after Run: expected state %s, got %s", StateReady, got)
|
||||
}
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("after Run: expected 200, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); body != "hello" {
|
||||
t.Errorf("expected body %q, got %q", "hello", body)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
select {
|
||||
case err := <-runErr:
|
||||
if err != nil {
|
||||
t.Errorf("Run() after Stop: expected nil, got %v", err)
|
||||
}
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
|
||||
// after stop: handler cleared
|
||||
rr = httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("after stop: expected 503, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||
t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCommand_Run_Idempotent(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
if err := p.Run(testStartTimeout); err == nil {
|
||||
t.Error("second Run() while running: expected error, got nil")
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCommand_Stop_Idempotent(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() before Run(): %v", err)
|
||||
}
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("first Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("second Stop() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is
|
||||
// executing its health-check loop returns ErrAbort to the Run caller.
|
||||
//
|
||||
// A blocking mock HTTP server is used as the proxy so the test can deterministically
|
||||
// know when doStart is inside the health-check loop before issuing Stop.
|
||||
func TestProcessCommand_StopCancelsRun(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
healthCheckStarted := make(chan struct{}, 1)
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Signal that a health check is in-flight, then block until the client
|
||||
// cancels (which happens when Stop cancels the start context).
|
||||
select {
|
||||
case healthCheckStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-r.Context().Done()
|
||||
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer mock.Close()
|
||||
|
||||
// simple-responder is the real process; health checks go to the blocking mock.
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 30,
|
||||
})
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- p.Run(testStartTimeout)
|
||||
}()
|
||||
|
||||
// Block until doStart is actually performing a health check, guaranteeing
|
||||
// that Run is in-flight when Stop is called.
|
||||
<-healthCheckStarted
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
|
||||
if err := <-runErrCh; !errors.Is(err, ErrStartAborted) {
|
||||
t.Errorf("expected ErrStartAborted from Run, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the
|
||||
// parent context while doStart is health-checking causes the process to
|
||||
// transition to StateShutdown promptly, not wait for the health-check timeout.
|
||||
//
|
||||
// This is the config-reload race: Stop() returns early when parentCtx is
|
||||
// already done and never writes to stopCh, so without a parentCtx.Done()
|
||||
// case in the inner select, the process would keep loading indefinitely.
|
||||
func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
healthCheckStarted := make(chan struct{}, 1)
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case healthCheckStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-r.Context().Done()
|
||||
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer mock.Close()
|
||||
|
||||
parentCtx, cancelParent := context.WithCancel(context.Background())
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p, err := New(parentCtx, t.Name(), config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 60,
|
||||
}, logger, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() { runErrCh <- p.Run(60 * time.Second) }()
|
||||
|
||||
<-healthCheckStarted
|
||||
|
||||
// Cancel parent context to simulate a config reload tearing down the old server.
|
||||
cancelParent()
|
||||
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if !strings.Contains(err.Error(), "shutdown") {
|
||||
t.Errorf("Run error = %v, want shutdown error", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("process did not shut down within 5s after parent context cancel during start")
|
||||
}
|
||||
|
||||
// Run() may return before the run() goroutine writes StateShutdown;
|
||||
// poll briefly to avoid a spurious race in the assertion.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if p.State() == StateShutdown {
|
||||
break
|
||||
}
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
if got := p.State(); got != StateShutdown {
|
||||
t.Errorf("after cancel: expected StateShutdown, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on
|
||||
// fresh processes to confirm they are reusable.
|
||||
func TestProcessCommand_RunStopCycle(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
for i := range 3 {
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("cycle %d Stop() error: %v", i, err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatalf("cycle %d: Run() did not return after Stop", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path:
|
||||
// the upstream responds healthy on /health (so Run completes), then on the
|
||||
// actual proxied request it hijacks the connection and closes it mid-body.
|
||||
// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error,
|
||||
// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover
|
||||
// and log the disconnect.
|
||||
//
|
||||
// Requests are issued through an httptest.NewServer wrapping the process so
|
||||
// the panic actually fires (httputil only panics on copy errors when the
|
||||
// request carries http.ServerContextKey, which a real server sets).
|
||||
//
|
||||
// see: https://github.com/golang/go/issues/23643
|
||||
func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// Send a Content-Length that promises 100 bytes, deliver only a few,
|
||||
// then slam the connection shut. The reverse proxy will see EOF
|
||||
// before the body is fully copied and panic with ErrAbortHandler.
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Errorf("upstream: hijack not supported")
|
||||
return
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Errorf("upstream: hijack: %v", err)
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial"))
|
||||
_ = conn.Close()
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
// Capture proxy log output so we can assert the recover message was
|
||||
// emitted by handlerFn.
|
||||
logBuf := &syncBuffer{}
|
||||
proxyLogger := logmon.NewWriter(logBuf)
|
||||
procLogger := logmon.NewWriter(io.Discard)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p, err := New(context.Background(), t.Name(), config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: upstream.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
}, procLogger, proxyLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
_ = runAsync(t, p)
|
||||
|
||||
// Wrap p in an httptest server so requests get http.ServerContextKey
|
||||
// automatically — that is what makes httputil.ReverseProxy raise the panic.
|
||||
front := httptest.NewServer(p)
|
||||
t.Cleanup(front.Close)
|
||||
|
||||
resp, err := http.Get(front.URL + "/disconnect")
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
const want = "recovered from upstream disconnection"
|
||||
deadline := time.Now().Add(testReturnTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if strings.Contains(logBuf.String(), want) {
|
||||
return
|
||||
}
|
||||
time.Sleep(testLogPollInterval)
|
||||
}
|
||||
t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String())
|
||||
}
|
||||
|
||||
// syncBuffer is a concurrent-safe bytes.Buffer for capturing logmon output.
|
||||
type syncBuffer struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *syncBuffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.Write(p)
|
||||
}
|
||||
|
||||
func (b *syncBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.String()
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL
|
||||
// automatically stops itself after the idle timeout has elapsed following its
|
||||
// last request.
|
||||
func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
|
||||
cfg := config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 1, // 1-second TTL
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, cfg)
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady, got %s", got)
|
||||
}
|
||||
|
||||
// Make one request to prime the last-use timestamp.
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Wait for the TTL goroutine to fire and the process to fully stop.
|
||||
// Poll for StateStopped directly to avoid racing the StateStopping
|
||||
// intermediate state that sits between StateReady and StateStopped.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for p.State() != StateStopped && time.Now().Before(deadline) {
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped)
|
||||
}
|
||||
|
||||
// Run() should have returned nil (clean stop from TTL).
|
||||
select {
|
||||
case err := <-runErr:
|
||||
if err != nil {
|
||||
t.Errorf("Run() after TTL stop: expected nil, got %v", err)
|
||||
}
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after TTL-induced stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests
|
||||
// prevent the TTL goroutine from stopping the process, and that the TTL timer
|
||||
// resets after each request completes.
|
||||
func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 1, // 1-second TTL
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep sending requests for 1.5s — past the 1s TTL — and verify
|
||||
// the process never stops while traffic is flowing.
|
||||
stopAt := time.Now().Add(1500 * time.Millisecond)
|
||||
for time.Now().Before(stopAt) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
if p.State() != StateReady {
|
||||
t.Fatalf("process was stopped during active traffic (state=%s)", p.State())
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady while traffic was active, got %s", got)
|
||||
}
|
||||
|
||||
// Now stop manually to clean up.
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not
|
||||
// spawn a TTL goroutine — the process stays ready until explicitly stopped.
|
||||
func TestProcessCommand_TTL_ZeroDisables(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 0, // disabled
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady, got %s", got)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is
|
||||
// enough to confirm the process remains ready.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got)
|
||||
}
|
||||
|
||||
// Cleanly stop.
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing
|
||||
// pairs to exercise the race detector and verify no deadlocks occur.
|
||||
func TestProcessCommand_ConcurrentRunStop(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
for range 10 {
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
cfg := config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, cfg)
|
||||
|
||||
runDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(runDone)
|
||||
p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race
|
||||
}()
|
||||
go func() {
|
||||
p.Stop(testStopTimeout) //nolint: errcheck
|
||||
}()
|
||||
|
||||
// Backstop: the racing Stop may have arrived before Run got on the
|
||||
// channel (making it a no-op), so keep stopping until Run unblocks.
|
||||
deadline := time.After(testStartTimeout)
|
||||
for done := false; !done; {
|
||||
select {
|
||||
case <-runDone:
|
||||
done = true
|
||||
case <-deadline:
|
||||
t.Fatal("Run did not return")
|
||||
case <-time.After(testPollInterval):
|
||||
p.Stop(testStopTimeout) //nolint: errcheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
var transitions []shared.ProcessStateChangeEvent
|
||||
cancel := event.On(func(e shared.ProcessStateChangeEvent) {
|
||||
if e.ProcessName != t.Name() {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
transitions = append(transitions, e)
|
||||
mu.Unlock()
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
<-runErr
|
||||
|
||||
// Events are delivered asynchronously; give the dispatcher a moment.
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
mu.Lock()
|
||||
n := len(transitions)
|
||||
mu.Unlock()
|
||||
if n >= 4 {
|
||||
break
|
||||
}
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for _, e := range transitions {
|
||||
if e.OldState == e.NewState {
|
||||
t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState)
|
||||
}
|
||||
}
|
||||
|
||||
want := []string{
|
||||
string(StateStopped) + "->" + string(StateStarting),
|
||||
string(StateStarting) + "->" + string(StateReady),
|
||||
string(StateReady) + "->" + string(StateStopping),
|
||||
string(StateStopping) + "->" + string(StateStopped),
|
||||
}
|
||||
got := make([]string, len(transitions))
|
||||
for i, e := range transitions {
|
||||
got[i] = e.OldState + "->" + e.NewState
|
||||
}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("transitions = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("transitions = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,775 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
type shutdownReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type unloadReq struct {
|
||||
targets []string
|
||||
timeout time.Duration
|
||||
respond chan struct{}
|
||||
}
|
||||
|
||||
type handlerReq struct {
|
||||
model string
|
||||
ctx context.Context
|
||||
respond chan handlerResp
|
||||
positionCh chan int
|
||||
}
|
||||
|
||||
type handlerResp struct {
|
||||
handleFunc http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type swapDone struct {
|
||||
modelID string
|
||||
err error
|
||||
}
|
||||
|
||||
type serveDoneEvent struct {
|
||||
modelID string
|
||||
}
|
||||
|
||||
type activeSwap struct {
|
||||
modelID string
|
||||
evict []string
|
||||
waiters []handlerReq
|
||||
}
|
||||
|
||||
// swapPlanner is the only piece of behaviour that differs between concrete
|
||||
// routers. baseRouter never inspects its internals.
|
||||
type swapPlanner interface {
|
||||
// EvictionFor returns running model IDs that must be stopped before
|
||||
// target can serve. alsoRunning lists models the baseRouter has already
|
||||
// committed to loading (in-flight swaps) which the planner cannot see
|
||||
// via process.State() yet. Pure decision; must not log.
|
||||
EvictionFor(target string, alsoRunning []string) []string
|
||||
|
||||
// OnSwapStart runs once at the start of every swap. Planners may log
|
||||
// their decision here at whatever verbosity they choose.
|
||||
OnSwapStart(target string)
|
||||
}
|
||||
|
||||
// baseRouter owns the channels, run-loop, and orchestration code shared by
|
||||
// every concrete router. Concrete routers embed *baseRouter and supply a
|
||||
// swapPlanner that captures how their eviction set is decided.
|
||||
type baseRouter struct {
|
||||
name string
|
||||
config config.Config
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
planner swapPlanner
|
||||
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
|
||||
handlerCh chan handlerReq
|
||||
shutdownCh chan shutdownReq
|
||||
unloadCh chan unloadReq
|
||||
swapDoneCh chan swapDone
|
||||
serveDoneCh chan serveDoneEvent
|
||||
|
||||
runDone chan struct{}
|
||||
|
||||
// testProcessed, when non-nil, receives one event after each handlerReq
|
||||
// or swapDone has been fully processed by run(). Tests use it to wait
|
||||
// for run() to reach a deterministic state without sleeping. serveDone
|
||||
// events are intentionally NOT signalled here so test event counts
|
||||
// remain stable.
|
||||
testProcessed chan struct{}
|
||||
}
|
||||
|
||||
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
return &baseRouter{
|
||||
name: name,
|
||||
config: conf,
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
planner: planner,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
handlerCh: make(chan handlerReq),
|
||||
shutdownCh: make(chan shutdownReq),
|
||||
unloadCh: make(chan unloadReq),
|
||||
swapDoneCh: make(chan swapDone),
|
||||
serveDoneCh: make(chan serveDoneEvent),
|
||||
runDone: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) notifyProcessed() {
|
||||
if b.testProcessed != nil {
|
||||
b.testProcessed <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) run() {
|
||||
defer close(b.runDone)
|
||||
|
||||
active := make(map[string]*activeSwap)
|
||||
inFlight := make(map[string]int)
|
||||
var queued []handlerReq
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-b.shutdownCh:
|
||||
b.handleShutdown(req, active, queued)
|
||||
return
|
||||
|
||||
case req := <-b.handlerCh:
|
||||
b.handleRequest(req, active, inFlight, &queued)
|
||||
b.notifyProcessed()
|
||||
|
||||
case req := <-b.unloadCh:
|
||||
b.handleUnload(req, active, inFlight, &queued)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.swapDoneCh:
|
||||
b.handleSwapDone(ev, active, inFlight, &queued)
|
||||
b.notifyProcessed()
|
||||
|
||||
case ev := <-b.serveDoneCh:
|
||||
b.handleServeDone(ev, active, inFlight, &queued)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// grant sends a response back to the caller of ServeHTTP and tells us
|
||||
// whether the caller was still there to receive it.
|
||||
//
|
||||
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
|
||||
// a select waiting on it. "Unbuffered" is the important word: a send only
|
||||
// completes when the other side is actively receiving. So if this send
|
||||
// succeeds, we know for a fact the caller picked up the response and will
|
||||
// act on it. If the caller has already given up (its request context was
|
||||
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
|
||||
// down, the send never lands, one of the other select cases fires, and we
|
||||
// report back that the grant did NOT happen.
|
||||
//
|
||||
// That distinction matters for in-flight bookkeeping — see grantHandler.
|
||||
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
|
||||
select {
|
||||
case req.respond <- resp:
|
||||
return true
|
||||
case <-req.ctx.Done():
|
||||
return false
|
||||
case <-b.shutdownCtx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// grantHandler is the "this caller can now use process p" path. It does
|
||||
// two things that must stay locked together:
|
||||
//
|
||||
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
|
||||
// HTTP request finishes, the run loop hears about it.
|
||||
// 2. Bump inFlight[modelID] so the router knows this process is busy and
|
||||
// refuses to evict it until the count comes back down.
|
||||
//
|
||||
// The increment is gated on grant() returning true. If grant() returns
|
||||
// false, the caller already walked away and trackedServe will never run —
|
||||
// which means no matching decrement will ever arrive on serveDoneCh.
|
||||
// Incrementing in that case would strand the counter at >0 forever and
|
||||
// the router would never again be willing to swap this model out.
|
||||
//
|
||||
// In short: increment if and only if we know a decrement is coming.
|
||||
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
|
||||
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
|
||||
inFlight[modelID]++
|
||||
}
|
||||
}
|
||||
|
||||
// trackedServe is the wrapper that closes the loop on in-flight tracking.
|
||||
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
|
||||
// send on serveDoneCh after the handler returns. That send is what tells
|
||||
// the run loop "this model now has one fewer request in flight — go look
|
||||
// at the queue again, you may be able to start a swap you previously had
|
||||
// to defer."
|
||||
//
|
||||
// The select on shutdownCtx.Done() is a release valve: if the router is
|
||||
// already shutting down, nobody is reading serveDoneCh, so we drop the
|
||||
// notification rather than blocking the HTTP goroutine forever.
|
||||
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
select {
|
||||
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
p.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest decides what to do with one incoming ServeHTTP request. It is
|
||||
// called from run() and never blocks indefinitely: any work that has to wait
|
||||
// (starting a process, stopping siblings, waiting for ready) is deferred to
|
||||
// a swap goroutine and reported back via swapDoneCh.
|
||||
//
|
||||
// The decision tree, in order:
|
||||
//
|
||||
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
|
||||
// 2. A swap to the same model is already in flight — attach this waiter so
|
||||
// one swap serves all callers that asked for the same model.
|
||||
// 3. Fast path — the target process is already ready, the planner sees
|
||||
// nothing to evict, and no in-flight swap is evicting it. Hand back its
|
||||
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
|
||||
// 4. Would collide with an in-flight swap (we'd stop their target, or
|
||||
// they're stopping us) — park in the queue for handleSwapDone to drain.
|
||||
// 5. Would evict a process that is still handling requests — park in the
|
||||
// queue. handleServeDone will retry when the busy process drains.
|
||||
// 6. Otherwise — start a new swap. This may run in parallel with other
|
||||
// active swaps when their evict sets don't intersect.
|
||||
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
// (1) Unknown model.
|
||||
p, ok := b.processes[req.model]
|
||||
if !ok {
|
||||
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
|
||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||
return
|
||||
}
|
||||
|
||||
// (2) Join an in-flight swap for the same model.
|
||||
if s, ok := active[req.model]; ok {
|
||||
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
|
||||
s.waiters = append(s.waiters, req)
|
||||
return
|
||||
}
|
||||
|
||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||
|
||||
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
|
||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
|
||||
b.grantHandler(req, req.model, p, inFlight)
|
||||
return
|
||||
}
|
||||
|
||||
// (4) Collision with an in-flight swap — queue.
|
||||
if collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
|
||||
*queued = append(*queued, req)
|
||||
b.broadcastQueuePositions(*queued)
|
||||
return
|
||||
}
|
||||
|
||||
// (5) Would evict a busy process — queue until it drains.
|
||||
if conflictsWithInFlight(evict, inFlight) {
|
||||
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
|
||||
*queued = append(*queued, req)
|
||||
b.broadcastQueuePositions(*queued)
|
||||
return
|
||||
}
|
||||
|
||||
// (6) Start a new (possibly parallel) swap.
|
||||
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
|
||||
s := b.startSwap(req, evict)
|
||||
active[s.modelID] = s
|
||||
}
|
||||
|
||||
// handleSwapDone is called from run() when a swap goroutine reports that it
|
||||
// has finished. It fans out the result to every waiter that joined this swap,
|
||||
// removes the swap from the active map, and then walks the queue once,
|
||||
// promoting any items that no longer collide with the remaining active set.
|
||||
// FIFO order is preserved: items still blocked stay in place.
|
||||
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
s, ok := active[ev.modelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(active, ev.modelID)
|
||||
|
||||
for _, w := range s.waiters {
|
||||
if ev.err != nil {
|
||||
b.grant(w, handlerResp{err: ev.err})
|
||||
} else {
|
||||
p := b.processes[ev.modelID]
|
||||
b.grantHandler(w, ev.modelID, p, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
}
|
||||
|
||||
// handleServeDone is called from run() each time a tracked ServeHTTP
|
||||
// finishes. It decrements the per-model in-flight count and, when that
|
||||
// drops to zero, retries the queue: requests whose swap was deferred
|
||||
// because they would have evicted this (now-idle) process can now proceed.
|
||||
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
inFlight[ev.modelID]--
|
||||
if inFlight[ev.modelID] <= 0 {
|
||||
delete(inFlight, ev.modelID)
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
}
|
||||
}
|
||||
|
||||
// drainQueue walks the queued requests in order, re-running the handleRequest
|
||||
// decision tree against the (now smaller) active set. Items that can now start
|
||||
// or join become satisfied; items still blocked remain queued in original
|
||||
// order so they get another chance on the next swap completion.
|
||||
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
if len(*queued) == 0 {
|
||||
return
|
||||
}
|
||||
pending := *queued
|
||||
var remaining []handlerReq
|
||||
for _, req := range pending {
|
||||
p, ok := b.processes[req.model]
|
||||
if !ok {
|
||||
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
|
||||
continue
|
||||
}
|
||||
if s, ok := active[req.model]; ok {
|
||||
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
|
||||
s.waiters = append(s.waiters, req)
|
||||
continue
|
||||
}
|
||||
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
|
||||
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
|
||||
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
|
||||
b.grantHandler(req, req.model, p, inFlight)
|
||||
continue
|
||||
}
|
||||
if collidesWith(req.model, evict, active) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
if conflictsWithInFlight(evict, inFlight) {
|
||||
remaining = append(remaining, req)
|
||||
continue
|
||||
}
|
||||
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
|
||||
s := b.startSwap(req, evict)
|
||||
active[s.modelID] = s
|
||||
}
|
||||
*queued = remaining
|
||||
b.broadcastQueuePositions(*queued)
|
||||
}
|
||||
|
||||
// broadcastQueuePositions sends each queued request its current 1-indexed
|
||||
// position. Sends are non-blocking: if the channel is full, the old value is
|
||||
// drained first so the consumer always sees the latest position.
|
||||
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
|
||||
for i, req := range queued {
|
||||
pos := i + 1
|
||||
select {
|
||||
case req.positionCh <- pos:
|
||||
default:
|
||||
select {
|
||||
case <-req.positionCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case req.positionCh <- pos:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
|
||||
swap := &activeSwap{
|
||||
modelID: initial.model,
|
||||
evict: evict,
|
||||
waiters: []handlerReq{initial},
|
||||
}
|
||||
b.planner.OnSwapStart(initial.model)
|
||||
go b.doSwap(initial.model, evict)
|
||||
return swap
|
||||
}
|
||||
|
||||
// activeTargets returns the IDs of every in-flight swap target except exclude.
|
||||
// baseRouter passes this to the planner so eviction decisions account for
|
||||
// models that have been committed to but have not yet transitioned to
|
||||
// StateStarting in their process state machine.
|
||||
func activeTargets(active map[string]*activeSwap, exclude string) []string {
|
||||
if len(active) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(active))
|
||||
for id := range active {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// collidesWith reports whether a new swap with this target and evict set can
|
||||
// safely run alongside the currently active swaps. Same-target callers should
|
||||
// JOIN (handled before this) — they do not collide with themselves.
|
||||
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
|
||||
for id, s := range active {
|
||||
if id == target {
|
||||
continue
|
||||
}
|
||||
if containsString(evict, id) {
|
||||
return true
|
||||
}
|
||||
if containsString(s.evict, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// conflictsWithInFlight reports whether any model in evict is still handling
|
||||
// requests. Stopping a busy process would cancel its callers' connections,
|
||||
// so the router defers the swap until those callers finish.
|
||||
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
|
||||
for _, m := range evict {
|
||||
if inFlight[m] > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(xs []string, s string) bool {
|
||||
for _, x := range xs {
|
||||
if x == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *baseRouter) doSwap(modelID string, toStop []string) {
|
||||
timeout := b.healthCheckTimeout()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mID := range toStop {
|
||||
wg.Add(1)
|
||||
go func(p process.Process, id string) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(timeout); err != nil {
|
||||
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(b.processes[mID], mID)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
target := b.processes[modelID]
|
||||
if target.State() == process.StateStopped {
|
||||
go func() {
|
||||
if err := target.Run(timeout); err != nil {
|
||||
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := target.WaitReady(b.shutdownCtx)
|
||||
|
||||
select {
|
||||
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
|
||||
case <-b.shutdownCtx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
|
||||
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
|
||||
|
||||
// Cancel shutdownCtx first so any waiter that is currently parked on
|
||||
// its respond channel can exit via its own shutdownCtx.Done() branch.
|
||||
// The grant calls below then either land (waiter happened to receive
|
||||
// before noticing shutdown) or fall through immediately via grant's
|
||||
// shutdownCtx case — either way the waiter sees a non-OK response.
|
||||
b.shutdownFn()
|
||||
|
||||
for _, s := range active {
|
||||
for _, w := range s.waiters {
|
||||
b.grant(w, handlerResp{err: shutdownErr})
|
||||
}
|
||||
}
|
||||
for _, w := range queued {
|
||||
b.grant(w, handlerResp{err: shutdownErr})
|
||||
}
|
||||
|
||||
stopTimeout := req.timeout
|
||||
if stopTimeout <= 0 {
|
||||
stopTimeout = b.healthCheckTimeout()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, p := range b.processes {
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(stopTimeout); err != nil {
|
||||
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
|
||||
}
|
||||
}(i, p)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if req.timeout > 0 {
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(req.timeout):
|
||||
<-done
|
||||
}
|
||||
} else {
|
||||
<-done
|
||||
}
|
||||
|
||||
req.respond <- nil
|
||||
}
|
||||
|
||||
func (b *baseRouter) healthCheckTimeout() time.Duration {
|
||||
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
|
||||
if t <= 0 {
|
||||
return 30 * time.Second
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (b *baseRouter) Handles(model string) bool {
|
||||
_, ok := b.processes[model]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
|
||||
if p, ok := b.processes[modelID]; ok {
|
||||
return p.Logger(), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RunningModels returns the current state of every process that is not stopped
|
||||
// or shut down. The processes map keys are fixed at construction and State()
|
||||
// is a snapshot, so this is safe to call without the run loop.
|
||||
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
|
||||
running := make(map[string]process.ProcessState)
|
||||
for id, p := range b.processes {
|
||||
st := p.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
running[id] = st
|
||||
}
|
||||
return running
|
||||
}
|
||||
|
||||
// Unload stops the named models, or every running model when none are named.
|
||||
// It blocks until each targeted process has stopped.
|
||||
//
|
||||
// The request is funneled through the run loop so eviction is coordinated
|
||||
// with the rest of the router's state: pending swap waiters for an
|
||||
// unloaded model are released with an error, queued requests for unloaded
|
||||
// models are dropped, and any deferred swaps that were waiting on those
|
||||
// models become eligible to start.
|
||||
//
|
||||
// In-flight requests being served by an unloaded process are not waited
|
||||
// for — Stop kills the upstream, those callers see whatever error the
|
||||
// reverse proxy surfaces and may retry. Their trackedServe defers fire
|
||||
// normally and decrement inFlight as the dying handlers return.
|
||||
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
|
||||
targets := models
|
||||
if len(targets) == 0 {
|
||||
targets = make([]string, 0, len(b.processes))
|
||||
for id := range b.processes {
|
||||
targets = append(targets, id)
|
||||
}
|
||||
}
|
||||
if len(targets) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
|
||||
select {
|
||||
case b.unloadCh <- req:
|
||||
case <-b.runDone:
|
||||
return
|
||||
}
|
||||
<-req.respond
|
||||
}
|
||||
|
||||
// handleUnload runs on the run loop in response to an Unload call. It
|
||||
// reconciles router-owned state with the impending Stop, then performs
|
||||
// the Stop synchronously so callers of Unload remain blocked until each
|
||||
// targeted process has actually exited.
|
||||
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(req.targets))
|
||||
for _, id := range req.targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
// Release waiters of any in-flight swap whose target is being
|
||||
// unloaded. The swap goroutine itself is left to finish on its own;
|
||||
// when its swapDone arrives, handleSwapDone will find no entry in
|
||||
// active and silently drop it.
|
||||
for id := range targetSet {
|
||||
s, ok := active[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, w := range s.waiters {
|
||||
b.grant(w, handlerResp{err: unloadErr})
|
||||
}
|
||||
delete(active, id)
|
||||
}
|
||||
|
||||
// Drop queued requests addressed to unloaded models. Requests for
|
||||
// other models stay queued and may benefit from drainQueue at the end.
|
||||
if len(*queued) > 0 {
|
||||
kept := (*queued)[:0]
|
||||
for _, w := range *queued {
|
||||
if targetSet[w.model] {
|
||||
b.grant(w, handlerResp{err: unloadErr})
|
||||
continue
|
||||
}
|
||||
kept = append(kept, w)
|
||||
}
|
||||
*queued = kept
|
||||
}
|
||||
|
||||
// Stop the targeted processes. Done synchronously so Unload's caller
|
||||
// can rely on "after Unload returns, the process is stopped". inFlight
|
||||
// is intentionally NOT cleared here: each dying handler will fire its
|
||||
// trackedServe defer and reach handleServeDone in the normal way once
|
||||
// the run loop is free again.
|
||||
var wg sync.WaitGroup
|
||||
for id := range targetSet {
|
||||
p, ok := b.processes[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(id string, p process.Process) {
|
||||
defer wg.Done()
|
||||
if err := p.Stop(req.timeout); err != nil {
|
||||
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
|
||||
}
|
||||
}(id, p)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Removing entries from active above may have unblocked queued
|
||||
// requests that previously collided with the now-cancelled swaps.
|
||||
b.drainQueue(active, inFlight, queued)
|
||||
|
||||
close(req.respond)
|
||||
}
|
||||
|
||||
func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
||||
if !b.shuttingDown.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("%s shutdown already in progress", b.name)
|
||||
}
|
||||
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
|
||||
select {
|
||||
case b.shutdownCh <- req:
|
||||
case <-b.runDone:
|
||||
return nil
|
||||
}
|
||||
return <-req.respond
|
||||
}
|
||||
|
||||
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if b.shuttingDown.Load() {
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
data, err := FetchContext(req, b.config)
|
||||
if err != nil {
|
||||
SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
hr := handlerReq{
|
||||
model: data.ModelID,
|
||||
ctx: req.Context(),
|
||||
// Unbuffered: a successful send on respond proves the waiter is
|
||||
// alive and consuming. grant() relies on this to avoid handing a
|
||||
// handleFunc to a cancelled waiter and leaking the inFlight count.
|
||||
respond: make(chan handlerResp),
|
||||
positionCh: make(chan int, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case b.handlerCh <- hr:
|
||||
case <-req.Context().Done():
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
isModelReady := false
|
||||
if p, ok := b.processes[data.ModelID]; ok {
|
||||
isModelReady = p.State() == process.StateReady
|
||||
}
|
||||
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
|
||||
|
||||
var lw *loadingWriter
|
||||
cancelLoad := func() {}
|
||||
if shouldShowLoading {
|
||||
var swapCtx context.Context
|
||||
swapCtx, cancelLoad = context.WithCancel(req.Context())
|
||||
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
|
||||
go lw.start(swapCtx)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case pos := <-hr.positionCh:
|
||||
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
|
||||
case <-swapCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var resp handlerResp
|
||||
select {
|
||||
case resp = <-hr.respond:
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
}
|
||||
return
|
||||
case <-b.shutdownCtx.Done():
|
||||
cancelLoad()
|
||||
if lw != nil {
|
||||
lw.waitForCompletion(1 * time.Second)
|
||||
}
|
||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||
return
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
SendError(w, req, resp.err)
|
||||
return
|
||||
}
|
||||
resp.handleFunc(w, req)
|
||||
}
|
||||
@@ -0,0 +1,863 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// stubPlanner is a swapPlanner that returns a fixed eviction list per target
|
||||
// and never logs. It lets the base-router tests cover shared run-loop
|
||||
// behaviour without dragging in either real router's eviction rules.
|
||||
type stubPlanner struct {
|
||||
evict map[string][]string
|
||||
}
|
||||
|
||||
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
|
||||
if s.evict == nil {
|
||||
return nil
|
||||
}
|
||||
return s.evict[target]
|
||||
}
|
||||
|
||||
func (s *stubPlanner) OnSwapStart(string) {}
|
||||
|
||||
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
|
||||
t.Helper()
|
||||
conf := config.Config{HealthCheckTimeout: 5}
|
||||
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||
b.testProcessed = make(chan struct{}, 64)
|
||||
go b.run()
|
||||
t.Cleanup(func() {
|
||||
if !b.shuttingDown.Load() {
|
||||
_ = b.Shutdown(time.Second)
|
||||
}
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
func TestBaseRouter_RunningModels(t *testing.T) {
|
||||
ready := newFakeProcess("ready")
|
||||
ready.markReady()
|
||||
starting := newFakeProcess("starting")
|
||||
starting.setState(process.StateStarting)
|
||||
stopped := newFakeProcess("stopped")
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{
|
||||
"ready": ready, "starting": starting, "stopped": stopped,
|
||||
}, &stubPlanner{})
|
||||
|
||||
running := b.RunningModels()
|
||||
if len(running) != 2 {
|
||||
t.Fatalf("running=%v want 2 entries", running)
|
||||
}
|
||||
if running["ready"] != process.StateReady {
|
||||
t.Errorf("ready state=%q want ready", running["ready"])
|
||||
}
|
||||
if running["starting"] != process.StateStarting {
|
||||
t.Errorf("starting state=%q want starting", running["starting"])
|
||||
}
|
||||
if _, ok := running["stopped"]; ok {
|
||||
t.Errorf("stopped process should be excluded from RunningModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_UnloadAll(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
c := newFakeProcess("c")
|
||||
c.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
|
||||
b.Unload(time.Second)
|
||||
|
||||
if a.State() != process.StateStopped || c.State() != process.StateStopped {
|
||||
t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_UnloadSpecificModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
c := newFakeProcess("c")
|
||||
c.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
|
||||
b.Unload(time.Second, "a")
|
||||
|
||||
if a.State() != process.StateStopped {
|
||||
t.Errorf("a should be stopped, got %q", a.State())
|
||||
}
|
||||
if c.State() != process.StateReady {
|
||||
t.Errorf("c should remain ready, got %q", c.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its
|
||||
// Stop calls concurrently rather than stopping each process serially. Each
|
||||
// fakeProcess.Stop is pinned via stopBlock; the test only releases them
|
||||
// after observing every stopStarted, proving all three Stops were in
|
||||
// flight simultaneously.
|
||||
func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
a.stopBlock = make(chan struct{})
|
||||
pb := newFakeProcess("b")
|
||||
pb.markReady()
|
||||
pb.stopBlock = make(chan struct{})
|
||||
pc := newFakeProcess("c")
|
||||
pc.markReady()
|
||||
pc.stopBlock = make(chan struct{})
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{})
|
||||
|
||||
unloadDone := make(chan struct{})
|
||||
go func() {
|
||||
b.Unload(time.Second, "a", "b", "c")
|
||||
close(unloadDone)
|
||||
}()
|
||||
|
||||
// All three Stop calls must start before any of them are allowed to
|
||||
// complete. If Unload was serial, only one stopStarted would fire
|
||||
// until we released its stopBlock, and this would deadlock.
|
||||
for _, p := range []*fakeProcess{a, pb, pc} {
|
||||
select {
|
||||
case <-p.stopStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("Stop on %s never started — Unload is not parallel", p.id)
|
||||
}
|
||||
}
|
||||
|
||||
// Release them; Unload should now return.
|
||||
close(a.stopBlock)
|
||||
close(pb.stopBlock)
|
||||
close(pc.stopBlock)
|
||||
|
||||
select {
|
||||
case <-unloadDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Unload did not return after stops released")
|
||||
}
|
||||
|
||||
for _, p := range []*fakeProcess{a, pb, pc} {
|
||||
if p.State() != process.StateStopped {
|
||||
t.Errorf("%s state=%q want stopped", p.id, p.State())
|
||||
}
|
||||
if got := p.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("%s stopCalls=%d want 1", p.id, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
|
||||
// rejoins router state: a request whose swap to the unloaded model is
|
||||
// still in progress receives an error, instead of being abandoned
|
||||
// against a process that's about to vanish.
|
||||
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false: the swap parks on WaitReady so we can interrupt
|
||||
// it with Unload before it completes.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
close(done)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
|
||||
<-a.runStarted
|
||||
|
||||
b.Unload(time.Second, "a")
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("ServeHTTP did not return after Unload")
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if a.State() != process.StateStopped {
|
||||
t.Errorf("a state=%q want stopped", a.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
|
||||
// for an unloaded model receive an error rather than sitting forever in
|
||||
// the queue against state the router no longer maintains.
|
||||
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
// Loading B evicts A — so a request for B while A is loading queues.
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||
|
||||
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// r2 for B collides with A's in-flight swap and queues.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Unload B — r2 (queued, targeting B) must be released with an error.
|
||||
b.Unload(time.Second, "b")
|
||||
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("queued B request did not return after Unload(b)")
|
||||
}
|
||||
if w2.Code == http.StatusOK {
|
||||
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
|
||||
}
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
|
||||
}
|
||||
|
||||
// Release r1 so the test cleans up cleanly.
|
||||
a.markReady()
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r1 did not complete after a.markReady")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_FastPath(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("serveCalls=%d want 1", got)
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 0 {
|
||||
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_OnDemandStart(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.autoReady = true
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("runCalls=%d want 1", got)
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("serveCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so the swap parks on WaitReady until we release it.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
const N = 5
|
||||
var wg sync.WaitGroup
|
||||
codes := make([]int, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
codes[i] = w.Code
|
||||
}(i)
|
||||
}
|
||||
|
||||
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
|
||||
<-a.runStarted // swap goroutine reached Run()
|
||||
a.markReady()
|
||||
wg.Wait()
|
||||
|
||||
for i, c := range codes {
|
||||
if c != http.StatusOK {
|
||||
t.Errorf("request %d: status=%d", i, c)
|
||||
}
|
||||
}
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
|
||||
}
|
||||
if got := a.serveCalls.Load(); got != N {
|
||||
t.Errorf("serveCalls=%d want %d", got, N)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ContextCancel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady=false so swap parks forever until we mark ready.
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequestCtx(ctx, "a"))
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("a"))
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap
|
||||
<-a.runStarted
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("cancelled ServeHTTP did not return after ctx cancel")
|
||||
}
|
||||
|
||||
a.markReady()
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("non-cancelled ServeHTTP did not complete after swap")
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pa := newFakeProcess("b")
|
||||
|
||||
// Loading b must stop a.
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
|
||||
|
||||
// First request starts a swap to A; A's autoReady=false so it parks.
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// Second request for B should queue while A's swap is in flight.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
if got := pa.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
|
||||
}
|
||||
|
||||
// Release A's swap. B's swap should then run.
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
|
||||
<-pa.runStarted
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("A request did not complete")
|
||||
}
|
||||
pa.markReady()
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("queued B request did not complete after A's swap")
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
|
||||
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
|
||||
// second request for each model rides the fast path — either joining the
|
||||
// active swap, or being pulled out of the queue when handleSwapDone promotes
|
||||
// the next model.
|
||||
func TestBaseRouter_QueueCollation(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// Each model evicts the other two so all swaps are mutually exclusive.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
"b": {"a", "c"},
|
||||
"c": {"a", "b"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
var (
|
||||
completedMu sync.Mutex
|
||||
completed []string
|
||||
)
|
||||
record := func(id string) {
|
||||
completedMu.Lock()
|
||||
defer completedMu.Unlock()
|
||||
completed = append(completed, id)
|
||||
}
|
||||
|
||||
ids := []string{"a", "b", "c", "a", "b", "c"}
|
||||
var wg sync.WaitGroup
|
||||
for _, id := range ids {
|
||||
id := id
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest(id))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
|
||||
return
|
||||
}
|
||||
record(id)
|
||||
}()
|
||||
// Wait for run() to absorb this request before launching the next,
|
||||
// so handlerCh receives them in launch order.
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
|
||||
// All 6 are now parked in run()'s waiters/queue. Release each swap in
|
||||
// sequence, waiting deterministically for each promotion to fire.
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
|
||||
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
|
||||
|
||||
<-pc.runStarted
|
||||
pc.markReady()
|
||||
wg.Wait()
|
||||
|
||||
if got := len(completed); got != 6 {
|
||||
t.Fatalf("completed=%v want 6", completed)
|
||||
}
|
||||
|
||||
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
|
||||
// but waiter goroutines may be scheduled in any order after their respond
|
||||
// channel fires, so completion order isn't deterministic. Per-model counts
|
||||
// (combined with the runCalls checks below) are sufficient to prove queue
|
||||
// collation collapsed each pair into a single swap.
|
||||
aDone, bDone, cDone := 0, 0, 0
|
||||
for _, id := range completed {
|
||||
switch id {
|
||||
case "a":
|
||||
aDone++
|
||||
case "b":
|
||||
bDone++
|
||||
case "c":
|
||||
cDone++
|
||||
}
|
||||
}
|
||||
if aDone != 2 || bDone != 2 || cDone != 2 {
|
||||
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
|
||||
}
|
||||
|
||||
// Single swap per model — the second request for each must have ridden
|
||||
// the fast path (joined active swap or joined a queued sibling), not
|
||||
// triggered an extra Run.
|
||||
if got := a.runCalls.Load(); got != 1 {
|
||||
t.Errorf("a.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := pb.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := pc.runCalls.Load(); got != 1 {
|
||||
t.Errorf("c.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
|
||||
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
|
||||
// before either process is marked ready.
|
||||
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
// Empty evict sets for both: they can load in parallel.
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Both swaps must have reached Run() before either is marked ready —
|
||||
// proves they ran in parallel rather than serializing.
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request A did not complete")
|
||||
}
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request B did not complete")
|
||||
}
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
|
||||
// unblocks every queued request that no longer collides — they all start in
|
||||
// parallel rather than one-per-completion.
|
||||
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// A's swap evicts both B and C, so B and C must queue. Once A finishes
|
||||
// B and C themselves have empty evict sets, so they can start together.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"a": {"b", "c"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
<-a.runStarted
|
||||
|
||||
// B and C arrive while A is loading. evict_b and evict_c are empty,
|
||||
// but collidesWith returns true because they appear in A's evict set.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
w3 := httptest.NewRecorder()
|
||||
done3 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w3, newRequest("c"))
|
||||
close(done3)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started early: runCalls=%d", got)
|
||||
}
|
||||
if got := pc.runCalls.Load(); got != 0 {
|
||||
t.Errorf("c started early: runCalls=%d", got)
|
||||
}
|
||||
|
||||
// Release A. The swapDone handler should drain the queue and start
|
||||
// both B and C in parallel.
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
|
||||
<-pb.runStarted
|
||||
<-pc.runStarted
|
||||
|
||||
pb.markReady()
|
||||
pc.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2, done3} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
|
||||
// the shutdown error to every waiter on every active swap AND to every
|
||||
// queued request.
|
||||
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
pc := newFakeProcess("c")
|
||||
|
||||
// a and b load in parallel (empty evicts). c collides with both.
|
||||
planner := &stubPlanner{evict: map[string][]string{
|
||||
"c": {"a", "b"},
|
||||
}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
|
||||
|
||||
const waitersPer = 2
|
||||
var wg sync.WaitGroup
|
||||
codes := make([]int, 0, 2*waitersPer+1)
|
||||
var codesMu sync.Mutex
|
||||
record := func(code int) {
|
||||
codesMu.Lock()
|
||||
codes = append(codes, code)
|
||||
codesMu.Unlock()
|
||||
}
|
||||
|
||||
launch := func(model string) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest(model))
|
||||
record(w.Code)
|
||||
}()
|
||||
}
|
||||
|
||||
// Active swaps for a and b, each with 2 waiters.
|
||||
for i := 0; i < waitersPer; i++ {
|
||||
launch("a")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
for i := 0; i < waitersPer; i++ {
|
||||
launch("b")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
}
|
||||
// c collides with both → queues.
|
||||
launch("c")
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
if err := b.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
codesMu.Lock()
|
||||
defer codesMu.Unlock()
|
||||
if len(codes) != 2*waitersPer+1 {
|
||||
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
|
||||
}
|
||||
for i, c := range codes {
|
||||
if c == http.StatusOK {
|
||||
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
|
||||
// is not stopped to satisfy another model's swap while it is still handling
|
||||
// a request.
|
||||
//
|
||||
// Sequence:
|
||||
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
|
||||
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
|
||||
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
|
||||
// intent collides with A.
|
||||
// 4. r1 released — A finishes r1, then r3 is served by A.
|
||||
// 5. B's swap then proceeds; r2 is served by B.
|
||||
//
|
||||
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
|
||||
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
|
||||
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
// autoReady left false: we markReady manually after observing runStarted,
|
||||
// so autoReady's setState(Ready) cannot race with a later Stop and leave
|
||||
// A in Ready, masking the bug.
|
||||
a.serveBlock = make(chan struct{})
|
||||
pb := newFakeProcess("b")
|
||||
// Same reasoning for B: park its swap on WaitReady until we choose.
|
||||
|
||||
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
|
||||
|
||||
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, b.testProcessed, 1) // swapDone for A
|
||||
<-a.serveStarted
|
||||
|
||||
// r2 — would evict A. A must not be stopped while r1 is in flight.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// r3 — another request for A, arrives behind r2 and queues because
|
||||
// B's swap intent (which evicts A) is recorded as active.
|
||||
w3 := httptest.NewRecorder()
|
||||
done3 := make(chan struct{})
|
||||
go func() {
|
||||
b.ServeHTTP(w3, newRequest("a"))
|
||||
close(done3)
|
||||
}()
|
||||
waitProcessed(t, b.testProcessed, 1)
|
||||
|
||||
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
|
||||
// The router must hold off B's swap until A has drained.
|
||||
close(a.serveBlock)
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r1 did not complete after serveBlock release")
|
||||
}
|
||||
|
||||
// Wait for B.Run before marking it ready: markReady before Run would
|
||||
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
|
||||
// implementation B's swap only starts after A has drained; in the
|
||||
// current implementation it has already started — either way runStarted
|
||||
// fires.
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r2 did not complete after B marked ready")
|
||||
}
|
||||
select {
|
||||
case <-done3:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("r3 did not complete")
|
||||
}
|
||||
|
||||
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
|
||||
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
|
||||
}
|
||||
if w1.Body.String() != "ok:a" {
|
||||
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
|
||||
}
|
||||
if w3.Body.String() != "ok:a" {
|
||||
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
|
||||
}
|
||||
if w2.Body.String() != "ok:b" {
|
||||
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
|
||||
}
|
||||
if a.stoppedWhileServing.Load() {
|
||||
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_ModelNotFound(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("unknown"))
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
pb := newFakeProcess("b")
|
||||
pb.markReady()
|
||||
go pb.Run(0)
|
||||
|
||||
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
|
||||
|
||||
if err := b.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("b.stopCalls=%d want 1", got)
|
||||
}
|
||||
|
||||
// Subsequent ServeHTTP should report 5xx.
|
||||
w := httptest.NewRecorder()
|
||||
b.ServeHTTP(w, newRequest("a"))
|
||||
if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Second Shutdown should report already in progress.
|
||||
if err := b.Shutdown(0); err == nil {
|
||||
t.Errorf("second Shutdown returned nil, want error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
*baseRouter
|
||||
}
|
||||
|
||||
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
if existing, dup := modelToGroup[mid]; dup {
|
||||
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
|
||||
}
|
||||
modelToGroup[mid] = gid
|
||||
}
|
||||
}
|
||||
|
||||
planner := &groupPlanner{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
}
|
||||
|
||||
processes := make(map[string]process.Process, len(modelToGroup))
|
||||
base := newBaseRouter("group", conf, processes, planner, proxylog)
|
||||
planner.processes = processes
|
||||
|
||||
for mid := range modelToGroup {
|
||||
modelCfg, _, ok := conf.FindConfig(mid)
|
||||
if !ok {
|
||||
base.shutdownFn()
|
||||
return nil, fmt.Errorf("no model config for %q", mid)
|
||||
}
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
}
|
||||
|
||||
g := &Group{baseRouter: base}
|
||||
go base.run()
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// groupPlanner decides evictions from static group configuration.
|
||||
//
|
||||
// Same-group siblings are stopped when the group has swap=true. Cross-group
|
||||
// members are stopped only when the target's group is exclusive; loading a
|
||||
// model from a non-exclusive group leaves running exclusive groups alone,
|
||||
// matching the gotcha in the original ProcessGroup behaviour.
|
||||
type groupPlanner struct {
|
||||
config config.Config
|
||||
modelToGroup map[string]string
|
||||
processes map[string]process.Process
|
||||
}
|
||||
|
||||
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||
tg := p.modelToGroup[target]
|
||||
tgCfg := p.config.Groups[tg]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
var result []string
|
||||
consider := func(mID string) {
|
||||
if mID == target {
|
||||
return
|
||||
}
|
||||
if _, dup := seen[mID]; dup {
|
||||
return
|
||||
}
|
||||
og := p.modelToGroup[mID]
|
||||
switch {
|
||||
case og == tg && tgCfg.Swap:
|
||||
seen[mID] = struct{}{}
|
||||
result = append(result, mID)
|
||||
// the previous ProcessGroup behaviour did not unload exclusive groups
|
||||
// when loading a non-exclusive model. This maintains that gotcha
|
||||
// for backwards compatibility. The newer swap matrix approach does not
|
||||
// have this issue.
|
||||
case og != tg && tgCfg.Exclusive:
|
||||
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
|
||||
seen[mID] = struct{}{}
|
||||
result = append(result, mID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for mID, proc := range p.processes {
|
||||
st := proc.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
consider(mID)
|
||||
}
|
||||
for _, mID := range alsoRunning {
|
||||
consider(mID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *groupPlanner) OnSwapStart(target string) {}
|
||||
@@ -0,0 +1,331 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// newTestGroup builds a Group directly from the supplied processes and config,
|
||||
// bypassing NewGroup's call to process.New.
|
||||
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
|
||||
t.Helper()
|
||||
modelToGroup := make(map[string]string)
|
||||
for gid, gcfg := range conf.Groups {
|
||||
for _, mid := range gcfg.Members {
|
||||
modelToGroup[mid] = gid
|
||||
}
|
||||
}
|
||||
planner := &groupPlanner{
|
||||
config: conf,
|
||||
modelToGroup: modelToGroup,
|
||||
processes: processes,
|
||||
}
|
||||
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
g := &Group{baseRouter: base}
|
||||
go base.run()
|
||||
t.Cleanup(func() {
|
||||
if !g.shuttingDown.Load() {
|
||||
_ = g.Shutdown(time.Second)
|
||||
}
|
||||
})
|
||||
return g
|
||||
}
|
||||
|
||||
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
|
||||
conf := config.Config{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Members: []string{"a"}},
|
||||
},
|
||||
Models: map[string]config.ModelConfig{
|
||||
"a": {},
|
||||
},
|
||||
}
|
||||
log := logmon.NewWriter(io.Discard)
|
||||
if _, err := NewGroup(conf, log, log); err == nil {
|
||||
t.Fatalf("expected error for duplicate membership")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0) // park a Run goroutine so Stop has something to release
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
if got := b.serveCalls.Load(); got != 1 {
|
||||
t.Errorf("b.serveCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_CrossGroupExclusive(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for
|
||||
// models in distinct non-exclusive groups load in parallel rather than
|
||||
// serializing through the router's run loop.
|
||||
func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
// Both groups load concurrently — both must reach Run() before either is
|
||||
// marked ready. If the router still serialised, only one would proceed.
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
|
||||
// (Swap=true) serialise even when both arrive while neither has reached
|
||||
// StateStarting yet — the alsoRunning hint to the planner closes that race.
|
||||
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
// Request B arrives before A transitions to StateStarting in the process
|
||||
// state machine. Without the alsoRunning hint, the planner would not see
|
||||
// A as running, and B would start in parallel, violating Swap=true.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
g.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, g.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started in parallel: runCalls=%d want 0", got)
|
||||
}
|
||||
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_PersistentNotEvicted verifies that a group with persistent=true
|
||||
// is never evicted when another exclusive group starts loading. The running
|
||||
// model in the persistent group stays alive alongside the new one.
|
||||
func TestGroup_PersistentNotEvicted(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
|
||||
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got)
|
||||
}
|
||||
if a.State() != process.StateStarting && a.State() != process.StateReady {
|
||||
t.Errorf("a state=%s want still running", a.State())
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible
|
||||
// gotcha from the original ProcessGroup: when a model in a non-exclusive group
|
||||
// is loaded, any running exclusive group keeps running. The two coexist.
|
||||
func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
conf := config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
|
||||
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
|
||||
},
|
||||
}
|
||||
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
g.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got)
|
||||
}
|
||||
if a.State() != process.StateStarting && a.State() != process.StateReady {
|
||||
t.Errorf("a state=%s want still running", a.State())
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// fakeProcess is an in-memory implementation of process.Process used to drive
|
||||
// the routers through their state machine without spawning real upstreams.
|
||||
type fakeProcess struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
state process.ProcessState
|
||||
readyCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
runStarted chan struct{} // closed on the first Run call
|
||||
stopStarted chan struct{} // closed on the first Stop call
|
||||
|
||||
autoReady bool
|
||||
|
||||
// serveBlock, when non-nil, makes ServeHTTP receive from it before
|
||||
// writing its response. Tests use this to hold a request in-flight.
|
||||
// Closing the channel releases every blocked ServeHTTP caller.
|
||||
serveBlock chan struct{}
|
||||
// serveStarted is closed on the first ServeHTTP entry, letting tests
|
||||
// wait deterministically for the handler to begin executing.
|
||||
serveStarted chan struct{}
|
||||
// stopBlock, when non-nil, makes Stop receive from it (after signalling
|
||||
// stopStarted) before completing. Tests use this to prove that several
|
||||
// Stop calls can be in flight simultaneously.
|
||||
stopBlock chan struct{}
|
||||
|
||||
runCalls atomic.Int32
|
||||
stopCalls atomic.Int32
|
||||
serveCalls atomic.Int32
|
||||
|
||||
// inFlightServe counts ServeHTTP calls currently inside the handler.
|
||||
// stoppedWhileServing flips true if Stop is ever called while that
|
||||
// counter is non-zero — a direct, race-free observation of the
|
||||
// "swap mid-request" anti-property.
|
||||
inFlightServe atomic.Int32
|
||||
stoppedWhileServing atomic.Bool
|
||||
}
|
||||
|
||||
func newFakeProcess(id string) *fakeProcess {
|
||||
return &fakeProcess{
|
||||
id: id,
|
||||
state: process.StateStopped,
|
||||
readyCh: make(chan struct{}),
|
||||
stopCh: make(chan struct{}),
|
||||
runStarted: make(chan struct{}),
|
||||
stopStarted: make(chan struct{}),
|
||||
serveStarted: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProcess) setState(s process.ProcessState) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.state = s
|
||||
if s == process.StateReady {
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
default:
|
||||
close(f.readyCh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProcess) State() process.ProcessState {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.state
|
||||
}
|
||||
|
||||
func (f *fakeProcess) markReady() { f.setState(process.StateReady) }
|
||||
|
||||
func (f *fakeProcess) Run(_ time.Duration) error {
|
||||
f.runCalls.Add(1)
|
||||
f.mu.Lock()
|
||||
if f.state != process.StateStopped {
|
||||
s := f.state
|
||||
f.mu.Unlock()
|
||||
return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s)
|
||||
}
|
||||
f.state = process.StateStarting
|
||||
sc := f.stopCh
|
||||
select {
|
||||
case <-f.runStarted:
|
||||
default:
|
||||
close(f.runStarted)
|
||||
}
|
||||
f.mu.Unlock()
|
||||
|
||||
if f.autoReady {
|
||||
f.setState(process.StateReady)
|
||||
}
|
||||
<-sc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeProcess) Stop(_ time.Duration) error {
|
||||
f.stopCalls.Add(1)
|
||||
if f.inFlightServe.Load() > 0 {
|
||||
f.stoppedWhileServing.Store(true)
|
||||
}
|
||||
f.mu.Lock()
|
||||
select {
|
||||
case <-f.stopStarted:
|
||||
default:
|
||||
close(f.stopStarted)
|
||||
}
|
||||
f.mu.Unlock()
|
||||
|
||||
// Test hook: hold Stop here so the test can prove multiple Stops are
|
||||
// in flight at the same time before any of them complete.
|
||||
if f.stopBlock != nil {
|
||||
<-f.stopBlock
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.state == process.StateStopped {
|
||||
return nil
|
||||
}
|
||||
f.state = process.StateStopped
|
||||
select {
|
||||
case <-f.stopCh:
|
||||
default:
|
||||
close(f.stopCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeProcess) WaitReady(ctx context.Context) error {
|
||||
f.mu.Lock()
|
||||
if f.state == process.StateReady {
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
rc := f.readyCh
|
||||
f.mu.Unlock()
|
||||
select {
|
||||
case <-rc:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProcess) Logger() *logmon.Monitor { return logmon.NewWriter(io.Discard) }
|
||||
|
||||
func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
|
||||
f.serveCalls.Add(1)
|
||||
f.inFlightServe.Add(1)
|
||||
defer f.inFlightServe.Add(-1)
|
||||
f.mu.Lock()
|
||||
select {
|
||||
case <-f.serveStarted:
|
||||
default:
|
||||
close(f.serveStarted)
|
||||
}
|
||||
f.mu.Unlock()
|
||||
if f.serveBlock != nil {
|
||||
<-f.serveBlock
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "ok:%s", f.id)
|
||||
}
|
||||
|
||||
// waitProcessed drains n events from ch, fataling on timeout. One event fires
|
||||
// per handlerReq or swapDone fully absorbed by run().
|
||||
func waitProcessed(t *testing.T, ch chan struct{}, n int) {
|
||||
t.Helper()
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("waitProcessed: only %d/%d events received", i, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newRequest(model string) *http.Request {
|
||||
body := fmt.Sprintf(`{"model":%q}`, model)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
return r
|
||||
}
|
||||
|
||||
func newRequestCtx(ctx context.Context, model string) *http.Request {
|
||||
return newRequest(model).WithContext(ctx)
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
var loadingPaths = []string{
|
||||
"/v1/chat/completions",
|
||||
}
|
||||
|
||||
func isLoadingPath(path string) bool {
|
||||
for _, p := range loadingPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type loadingWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
req *http.Request
|
||||
ctx context.Context
|
||||
logger *logmon.Monitor
|
||||
modelName string
|
||||
startTime time.Time
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pendingUpdate string
|
||||
|
||||
// closed by start when the goroutine finishes (after cleanup messages)
|
||||
done chan struct{}
|
||||
|
||||
// test-only: closed when start enters its loop
|
||||
loopStarted chan struct{}
|
||||
// test-only: override the 1s tick interval
|
||||
tickDuration time.Duration
|
||||
// test-only: override character streaming speed (0 = no delay)
|
||||
charPerSecond float64
|
||||
}
|
||||
|
||||
func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter {
|
||||
s := &loadingWriter{
|
||||
writer: w,
|
||||
req: req,
|
||||
ctx: req.Context(),
|
||||
logger: logger,
|
||||
modelName: modelName,
|
||||
startTime: time.Now(),
|
||||
tickDuration: 750 * time.Millisecond,
|
||||
charPerSecond: 75,
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream")
|
||||
s.Header().Set("Cache-Control", "no-cache")
|
||||
s.Header().Set("Connection", "keep-alive")
|
||||
s.WriteHeader(http.StatusOK)
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName))
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *loadingWriter) setUpdate(msg string) {
|
||||
s.pendingMu.Lock()
|
||||
s.pendingUpdate = msg
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) start(ctx context.Context) {
|
||||
s.done = make(chan struct{})
|
||||
defer close(s.done)
|
||||
|
||||
defer func() {
|
||||
// Skip cleanup writes if the client disconnected — the connection
|
||||
// is being torn down and flushing against it will panic.
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
duration := time.Since(s.startTime)
|
||||
s.sendData("\n")
|
||||
s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Time{}
|
||||
|
||||
ticker := time.NewTicker(s.tickDuration)
|
||||
defer ticker.Stop()
|
||||
|
||||
if s.loopStarted != nil {
|
||||
close(s.loopStarted)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.pendingMu.Lock()
|
||||
update := s.pendingUpdate
|
||||
s.pendingUpdate = ""
|
||||
s.pendingMu.Unlock()
|
||||
|
||||
if update != "" {
|
||||
s.sendData("\n")
|
||||
s.sendInline(update)
|
||||
s.sendData(" ")
|
||||
lastRemarkTime = time.Now()
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendData("\n")
|
||||
s.sendInline(remark)
|
||||
s.sendData(" ")
|
||||
lastRemarkTime = time.Now()
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
if s.done == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loadingWriter) sendInline(text string) {
|
||||
chunkSize := 10
|
||||
if s.charPerSecond > 0 {
|
||||
chunkSize = max(3, int(s.charPerSecond)/15)
|
||||
}
|
||||
|
||||
runes := []rune(text)
|
||||
for i := 0; i < len(runes); {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
end := i + chunkSize
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
chunk := string(runes[i:end])
|
||||
s.sendData(chunk)
|
||||
i = end
|
||||
|
||||
if i < len(runes) && s.charPerSecond > 0 {
|
||||
time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *loadingWriter) sendLine(line string) {
|
||||
if line == "" {
|
||||
s.sendData("\n")
|
||||
return
|
||||
}
|
||||
s.sendInline(line)
|
||||
s.sendData("\n")
|
||||
}
|
||||
|
||||
func (s *loadingWriter) sendData(data string) {
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
|
||||
return
|
||||
}
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *loadingWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *loadingWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package router
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting",
|
||||
"Reticulating splines",
|
||||
"Waking up the hamsters",
|
||||
"Teaching the model manners",
|
||||
"Convincing the GPU to participate",
|
||||
"Loading weights (they're heavy)",
|
||||
"Please enjoy this elevator music in your head",
|
||||
"Pretending to be productive",
|
||||
"Reading the entire internet, page by page",
|
||||
"Staring at the abyss, the abyss is buffering",
|
||||
"Applying layer after layer of disembodied cognition",
|
||||
"Remembering everything it forgot during quantization",
|
||||
"Counting to 405 billion, one parameter at a time",
|
||||
"Summoning the stochastic parroting",
|
||||
"Hold on, the GPU is questioning its existence",
|
||||
"Deciding which facts to hallucinate today",
|
||||
"Untangling the transformer spaghetti",
|
||||
"Warming up the token soup",
|
||||
"Your prompt is in a queue, behind 7 billion other thoughts",
|
||||
"Running `sudo apt-get install intelligence`",
|
||||
"Defragmenting the latent space",
|
||||
"Polishing each matrix multiplication by hand",
|
||||
"Whispering sweet nothings to the attention heads",
|
||||
"Aligning with human values, one reluctant epoch at a time",
|
||||
"The model is thinking about what it's about to think about",
|
||||
"Loading... and by loading we mean making you wait",
|
||||
"Spinning up the cloud GPU, please be patient while we burn your credits",
|
||||
"Applying duct tape to the context window",
|
||||
"Bribing the GPU scheduler for a timeslice",
|
||||
"Would you like to hear a fun fact while we load? Too bad.",
|
||||
"Hot swapping your sanity for an LLM",
|
||||
"Compressing optimism into FP16",
|
||||
"Ignoring 90% of the attention to save you 50% of the time",
|
||||
"Counting the exact same thing three times just to be sure",
|
||||
"Sorry, the inference you have reached is not in service",
|
||||
"Rotating the positional encodings counterclockwise for good luck",
|
||||
"Your call is very important to us. Please continue to hold.",
|
||||
"Unpacking the blobs. All 300GB of them.",
|
||||
"Initializing the thing that initializes the other thing",
|
||||
"Converting electricity into existential dread",
|
||||
"Flattening the curve... wait, the tensor. Flattening the tensor.",
|
||||
"Fetching the fetch of a fetch, callback hell edition",
|
||||
"The GPU is at 100%. The fan is now a helicopter.",
|
||||
"Baking the weights at 350° for a golden-brown inference",
|
||||
"Recalibrating the confidence of things it's still wrong about",
|
||||
"Have you tried turning it off and on again? No? Good, wait here.",
|
||||
"Simulating deep thought by pausing dramatically",
|
||||
"Loading the model that knows more than you but still can't count r's in 'strawberry'",
|
||||
"Convincing CUDA to cooperate. This may take a while.",
|
||||
"VRAM: 23.9GB used of 24GB. Living on the edge.",
|
||||
"Processing your request with the urgency of a DMV employee",
|
||||
"This model was trained on the entire internet, including that embarrassing blog you wrote in 2008",
|
||||
"Dispatching tokens through a series of increasingly confused matrix multiplies",
|
||||
"Gently lowering your expectations",
|
||||
"Applying softmax to our feelings about this load time",
|
||||
"Autoregressively generating disappointment, one token at a time",
|
||||
"The magic is happening. Somewhere. Probably.",
|
||||
"Synchronizing the parallel processes that run in parallel but really don't",
|
||||
"Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.",
|
||||
"Loading... just like it said 30 seconds ago. And will say 30 seconds from now.",
|
||||
"Pre-warming the cache so the first query is only slightly slower than the rest",
|
||||
"Have you considered that maybe your question wasn't worth all this compute?",
|
||||
"Downloading more RAM (no, really, we're mmap-ing the weights)",
|
||||
"Translating your prompt into math it barely understands",
|
||||
"Estimating your time remaining with 0% accuracy",
|
||||
"Buffering enthusiasm",
|
||||
"Model is loading. Go make some coffee. Or a three-course meal.",
|
||||
"Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'",
|
||||
"Polling for readiness in a loop that would make your CS professor weep",
|
||||
"Performing percussive maintenance on the attention mechanism",
|
||||
"This loading screen is singlehandedly reversing climate progress",
|
||||
"Decompressing the hopes and dreams of thousands of underpaid labelers",
|
||||
"Filling the key-value cache with the ghost of prompts past",
|
||||
"Currently at step 3 of 9,742 of loading. We'll get there. Eventually.",
|
||||
"If you stare at the spinner, it spins slower. It's science.",
|
||||
"Multiplying matricies with the enthusiasm of a teenager doing chores",
|
||||
"Applying `torch.nap()` until the model feels refreshed",
|
||||
"Reacquainting the model with the concept of 'facts' it forgot during fine-tuning",
|
||||
"Sorry for the wait. No, wait, we're not actually sorry.",
|
||||
"Your GPU is now a space heater with a side hustle in linear algebra",
|
||||
"Allocating memory like a billionaire allocates tax avoidance strategies",
|
||||
"The model saw \"As an AI language model\" and won't stop saying it now",
|
||||
"Installing dependencies you didn't know existed and will never use again",
|
||||
"Re-reading 'Attention Is All You Need' for the 400th time",
|
||||
"Convincing the embedding layer that context is overrated",
|
||||
"Manually untangling the residual connections with a tiny comb",
|
||||
"On hold with the cloud provider trying to explain why 8 H100s isn't enough",
|
||||
"Adjusting temperatures: model is 0.7, server room is 104°F",
|
||||
"Please hold while we justify this electricity bill to accounting",
|
||||
"Stacking decoder blocks like a Jenga tower at a LAN party",
|
||||
"Compensating for your lack of patience with our lack of speed",
|
||||
"This is a loading screen comment. Loading screens have comments now. Welcome to the future.",
|
||||
"Processing the entire works of Shakespeare backwards just in case",
|
||||
"The model is loading slower than your last `npm install`",
|
||||
"Rehearsing plausible-sounding explanations for why it got everything wrong",
|
||||
"Populating the context with filler while you wait for actual content",
|
||||
"Optimizing for BLEU score, which definitely correlates with making you laugh",
|
||||
"Generating an embedding for each and every letter of the alphabet, individually",
|
||||
"Coming soon: llama-swap v2 with actual performance improvements. Probably.",
|
||||
"Loading a model larger than your attention span",
|
||||
"Performing a seance to invoke the spirit of Geoff Hinton",
|
||||
"Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.",
|
||||
"Converting all the internet's bad opinions into a surprisingly useful autocomplete",
|
||||
"Laying down each layer with the care of a Michelin-starred pastry chef",
|
||||
"Checking if the model still thinks birds are government drones. Yep.",
|
||||
"Activating the neurons responsible for 'I cannot assist with that request'",
|
||||
"This model was trained on the same internet that brought you Rickrolling. You're welcome.",
|
||||
"Realigning the alignment so it aligns with the previous alignment",
|
||||
"Running `nvidia-smi` and sighing heavily",
|
||||
"If you close your eyes, the loading bar moves faster. Proven by science.",
|
||||
"EULA said 'by using this software you agree to wait forever' and you clicked Accept",
|
||||
"Zipping the GPUs to make them go faster",
|
||||
"Padding the context window with existential padding",
|
||||
"We could have used a smaller model but someone wanted 'quality'",
|
||||
"Disentangling the latent space into something resembling coherence",
|
||||
"Slow is smooth, smooth is fast, but this is just slow",
|
||||
"Memory-mapping like it's a AAA title from 2012",
|
||||
"Your patience has been tokenized and added to the training set. Thank you for your contribution.",
|
||||
"Loading is CPU-bound and your CPU is busy regretting its life choices",
|
||||
"Exploring the high-dimensional manifold of ways to say 'just a moment'",
|
||||
"The model is experiencing a brief but intense moment of imposter syndrome",
|
||||
"Initializing 7B parameters by rolling 7B 16-sided dice",
|
||||
"Panic! at the disk I/O",
|
||||
"Intelligence is loading... your definition of intelligence may vary",
|
||||
"This model was distilled. Unlike your patience, which is evaporating.",
|
||||
"Unzipping the model. It's a .gguf file, not a metaphor.",
|
||||
"Running inference on the concept of 'soon' to estimate remaining time",
|
||||
"Loading with all the speed of a government-funded IT project",
|
||||
"A blank terminal is a terrible thing to waste. Here's a loading message instead.",
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
|
||||
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
|
||||
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
|
||||
}
|
||||
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
|
||||
t.Errorf("Cache-Control: want no-cache, got %q", cc)
|
||||
}
|
||||
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
|
||||
t.Errorf("Connection: want keep-alive, got %q", conn)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.HasPrefix(body, "data: ") {
|
||||
t.Errorf("expected SSE data: prefix, got: %s", body)
|
||||
}
|
||||
|
||||
content := extractStreamedContent(body)
|
||||
if !strings.Contains(content, "━━━━━\n") {
|
||||
t.Errorf("missing separator in streamed content: %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
|
||||
t.Errorf("missing initial message in streamed content: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.WriteHeader(http.StatusCreated)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_WritePassthrough(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.Write([]byte("hello"))
|
||||
lw.Flush()
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "hello") {
|
||||
t.Errorf("Write passthrough failed, body: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.tickDuration = 10 * time.Millisecond
|
||||
lw.loopStarted = make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go lw.start(ctx)
|
||||
<-lw.loopStarted
|
||||
cancel()
|
||||
|
||||
if !lw.waitForCompletion(time.Second) {
|
||||
t.Fatal("waitForCompletion timed out")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "Done!") {
|
||||
t.Errorf("expected Done! message, body: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.tickDuration = 10 * time.Millisecond
|
||||
lw.charPerSecond = 0
|
||||
lw.loopStarted = make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go lw.start(ctx)
|
||||
<-lw.loopStarted
|
||||
|
||||
lw.setUpdate("custom status message")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
if !lw.waitForCompletion(time.Second) {
|
||||
t.Fatal("waitForCompletion timed out")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
content := extractStreamedContent(body)
|
||||
if !strings.Contains(content, "custom status message") {
|
||||
t.Errorf("expected setUpdate message in output, got: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_SendDataFormat(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.sendData("hello world")
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
|
||||
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
|
||||
}
|
||||
if !strings.HasPrefix(body, "data: ") {
|
||||
t.Errorf("expected data: prefix, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_SendLine(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.charPerSecond = 0
|
||||
|
||||
// Capture only the content from this sendLine call
|
||||
before := w.Body.Len()
|
||||
lw.sendLine("line content")
|
||||
after := w.Body.Len()
|
||||
chunkBody := w.Body.String()[before:after]
|
||||
|
||||
content := extractStreamedContent(chunkBody)
|
||||
if content != "line content\n" {
|
||||
t.Errorf("expected complete streamed line, got: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
lw.tickDuration = 10 * time.Millisecond
|
||||
lw.charPerSecond = 0
|
||||
lw.loopStarted = make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
lw.start(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-lw.loopStarted
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
body := w.Body.String()
|
||||
lines := countSSEMessages(body)
|
||||
if lines < 2 {
|
||||
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingWriter_ReqStored(t *testing.T) {
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
lw := newLoadingWriter(logger, "test-model", w, req)
|
||||
if lw.req != req {
|
||||
t.Fatal("req not stored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLoadingPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
{"/v1/chat/completions", true},
|
||||
{"/v1/chat/completions/extra", true},
|
||||
{"/v1/completions", false},
|
||||
{"/v1/embeddings", false},
|
||||
{"/health", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
if got := isLoadingPath(tt.path); got != tt.want {
|
||||
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", "model=llama3&stream=true", true},
|
||||
{"streaming false", "model=llama3&stream=false", false},
|
||||
{"no stream param", "model=llama3", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||
{"no stream param", `{"model":"llama3"}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := ExtractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if !got.Streaming {
|
||||
t.Error("Streaming should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func countSSEMessages(s string) int {
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
count := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func extractStreamedContent(body string) string {
|
||||
var result strings.Builder
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
jsonData := strings.TrimPrefix(line, "data: ")
|
||||
var msg struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
if len(msg.Choices) > 0 {
|
||||
result.WriteString(msg.Choices[0].Delta.ReasoningContent)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
type Matrix struct {
|
||||
*baseRouter
|
||||
}
|
||||
|
||||
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
|
||||
if conf.Matrix == nil {
|
||||
return nil, fmt.Errorf("matrix router requires a matrix configuration")
|
||||
}
|
||||
|
||||
planner := &matrixPlanner{
|
||||
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
|
||||
logger: proxylog,
|
||||
}
|
||||
|
||||
// Build a process for every model in the config. Any model can run alone
|
||||
// even if it is not part of a set; this mirrors proxy.NewMatrix.
|
||||
processes := make(map[string]process.Process, len(conf.Models))
|
||||
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
|
||||
planner.processes = processes
|
||||
|
||||
for mid, modelCfg := range conf.Models {
|
||||
procLog := logmon.NewWriter(upstreamlog)
|
||||
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
|
||||
if err != nil {
|
||||
base.shutdownFn()
|
||||
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
|
||||
}
|
||||
processes[mid] = p
|
||||
}
|
||||
|
||||
r := &Matrix{baseRouter: base}
|
||||
go base.run()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// matrixPlanner decides evictions by asking the matrix solver against the
|
||||
// current running set.
|
||||
type matrixPlanner struct {
|
||||
solver *matrixSolver
|
||||
processes map[string]process.Process
|
||||
logger *logmon.Monitor
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
|
||||
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) OnSwapStart(target string) {
|
||||
running := p.runningModels()
|
||||
result := p.solver.Solve(target, running)
|
||||
switch {
|
||||
case len(result.Evict) > 0:
|
||||
p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||
target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||
case len(running) == 0:
|
||||
p.logger.Infof("matrix: model=%s starting (no models running)", target)
|
||||
default:
|
||||
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *matrixPlanner) runningModels() []string {
|
||||
return p.runningSet(nil)
|
||||
}
|
||||
|
||||
// runningSet returns the union of live processes (State != Stopped/Shutdown)
|
||||
// and any extra IDs the baseRouter has already committed to loading but which
|
||||
// the process state machine has not yet reflected.
|
||||
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
|
||||
seen := make(map[string]struct{}, len(p.processes))
|
||||
var running []string
|
||||
for id, proc := range p.processes {
|
||||
st := proc.State()
|
||||
if st == process.StateStopped || st == process.StateShutdown {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
running = append(running, id)
|
||||
}
|
||||
for _, id := range alsoRunning {
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
running = append(running, id)
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
// matrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||
// It is safe for concurrent reads after construction.
|
||||
type matrixSolver struct {
|
||||
expandedSets []config.ExpandedSet // all valid model combinations
|
||||
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||
}
|
||||
|
||||
func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver {
|
||||
modelToSets := make(map[string][]int)
|
||||
for i, es := range expandedSets {
|
||||
for _, model := range es.Models {
|
||||
modelToSets[model] = append(modelToSets[model], i)
|
||||
}
|
||||
}
|
||||
|
||||
return &matrixSolver{
|
||||
expandedSets: expandedSets,
|
||||
evictCosts: evictCosts,
|
||||
modelToSets: modelToSets,
|
||||
}
|
||||
}
|
||||
|
||||
// solveResult describes what the solver decided.
|
||||
type solveResult struct {
|
||||
Evict []string // running models that must be stopped
|
||||
TargetSet []string // the chosen set of models (for informational purposes)
|
||||
SetName string // name of the chosen set
|
||||
DSL string // original DSL expression for the chosen set
|
||||
TotalCost int // total eviction cost
|
||||
}
|
||||
|
||||
// Solve determines which models to evict when a model is requested.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. If requestedModel is already running, no eviction needed.
|
||||
// 2. Find all sets containing requestedModel.
|
||||
// 3. If no sets found, the model runs alone; evict all running models.
|
||||
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||
// models NOT in that set.
|
||||
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||
// 6. Return models to evict and the chosen set.
|
||||
func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult {
|
||||
if slices.Contains(runningModels, requestedModel) {
|
||||
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||
return solveResult{
|
||||
TargetSet: runningModels,
|
||||
SetName: setName,
|
||||
DSL: dsl,
|
||||
}
|
||||
}
|
||||
|
||||
candidateIndices := s.modelToSets[requestedModel]
|
||||
|
||||
// Model not in any set: runs alone, evict everything.
|
||||
if len(candidateIndices) == 0 {
|
||||
evict := make([]string, len(runningModels))
|
||||
copy(evict, runningModels)
|
||||
return solveResult{
|
||||
Evict: evict,
|
||||
TargetSet: []string{requestedModel},
|
||||
}
|
||||
}
|
||||
|
||||
bestCost := -1
|
||||
bestIdx := -1
|
||||
|
||||
for _, idx := range candidateIndices {
|
||||
setModels := s.expandedSets[idx].Models
|
||||
cost := 0
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(setModels, running) {
|
||||
cost += s.evictCost(running)
|
||||
}
|
||||
}
|
||||
|
||||
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||
bestCost = cost
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
chosen := s.expandedSets[bestIdx]
|
||||
var evict []string
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(chosen.Models, running) {
|
||||
evict = append(evict, running)
|
||||
}
|
||||
}
|
||||
|
||||
return solveResult{
|
||||
Evict: evict,
|
||||
TargetSet: chosen.Models,
|
||||
SetName: chosen.SetName,
|
||||
DSL: chosen.DSL,
|
||||
TotalCost: bestCost,
|
||||
}
|
||||
}
|
||||
|
||||
// findMatchingSet finds the expanded set that contains all running models.
|
||||
// Returns the set name and DSL, or empty strings if no match.
|
||||
func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||
for _, idx := range s.modelToSets[requestedModel] {
|
||||
set := s.expandedSets[idx]
|
||||
allInSet := true
|
||||
for _, m := range runningModels {
|
||||
if !slices.Contains(set.Models, m) {
|
||||
allInSet = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSet {
|
||||
return set.SetName, set.DSL
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (s *matrixSolver) evictCost(model string) int {
|
||||
if cost, ok := s.evictCosts[model]; ok {
|
||||
return cost
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// newTestMatrix builds a Matrix router from supplied processes, bypassing
|
||||
// NewMatrix's call to process.New.
|
||||
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
planner := &matrixPlanner{
|
||||
solver: newMatrixSolver(expanded, evictCosts),
|
||||
processes: processes,
|
||||
logger: logger,
|
||||
}
|
||||
base := newBaseRouter("matrix", conf, processes, planner, logger)
|
||||
base.testProcessed = make(chan struct{}, 64)
|
||||
r := &Matrix{baseRouter: base}
|
||||
go base.run()
|
||||
t.Cleanup(func() {
|
||||
if !r.shuttingDown.Load() {
|
||||
_ = r.Shutdown(time.Second)
|
||||
}
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func baseMatrixConfig() config.Config {
|
||||
return config.Config{
|
||||
HealthCheckTimeout: 5,
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers
|
||||
// eviction of running models that are not in any shared set with it.
|
||||
func TestMatrix_SwapEvictsConflicting(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0) // park a Run goroutine so Stop has something to release
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
// Two single-model sets: a and b never coexist, so loading b must evict a.
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
|
||||
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_CoexistInSet verifies that a model is not evicted when the target
|
||||
// shares a set with it (the fast path applies if the target is already ready).
|
||||
func TestMatrix_CoexistInSet(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
a.markReady()
|
||||
go a.Run(0)
|
||||
|
||||
b := newFakeProcess("b")
|
||||
b.autoReady = true
|
||||
|
||||
// Both fit in s_ab, so b's swap should not stop a.
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, newRequest("b"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
|
||||
}
|
||||
if got := b.runCalls.Load(); got != 1 {
|
||||
t.Errorf("b.runCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_CoexistingSetParallel verifies that two models that share an
|
||||
// expanded set load in parallel — the solver returns empty Evict for both,
|
||||
// the collision predicate clears them, and both swaps run together.
|
||||
func TestMatrix_CoexistingSetParallel(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
<-a.runStarted
|
||||
<-pb.runStarted
|
||||
|
||||
a.markReady()
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
|
||||
}
|
||||
if got := pb.stopCalls.Load(); got != 0 {
|
||||
t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrix_IncompatibleQueues verifies that the second request for a model
|
||||
// that cannot coexist with the in-flight first model queues until the first
|
||||
// completes, and then evicts it. This exercises the alsoRunning hint via the
|
||||
// matrix solver's union into runningSet.
|
||||
func TestMatrix_IncompatibleQueues(t *testing.T) {
|
||||
a := newFakeProcess("a")
|
||||
pb := newFakeProcess("b")
|
||||
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
|
||||
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
|
||||
}
|
||||
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w1, newRequest("a"))
|
||||
close(done1)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
// B arrives before A transitions to StateStarting. The solver sees A via
|
||||
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
|
||||
w2 := httptest.NewRecorder()
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w2, newRequest("b"))
|
||||
close(done2)
|
||||
}()
|
||||
waitProcessed(t, r.testProcessed, 1)
|
||||
|
||||
if got := pb.runCalls.Load(); got != 0 {
|
||||
t.Errorf("b started in parallel: runCalls=%d want 0", got)
|
||||
}
|
||||
|
||||
<-a.runStarted
|
||||
a.markReady()
|
||||
waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
|
||||
<-pb.runStarted
|
||||
pb.markReady()
|
||||
|
||||
for i, ch := range []chan struct{}{done1, done2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("request %d did not complete", i)
|
||||
}
|
||||
}
|
||||
if got := a.stopCalls.Load(); got != 1 {
|
||||
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrixSolver_TieBreakDefinitionOrder pins the solver's tie-break rule:
|
||||
// when multiple candidate sets have equal eviction cost, the earlier-defined
|
||||
// set wins.
|
||||
func TestMatrixSolver_TieBreakDefinitionOrder(t *testing.T) {
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "first", DSL: "a & b", Models: []string{"a", "b"}},
|
||||
{SetName: "second", DSL: "a & c", Models: []string{"a", "c"}},
|
||||
}
|
||||
s := newMatrixSolver(expanded, nil)
|
||||
|
||||
// No models running, request "a": both sets have cost 0 and contain a.
|
||||
// Definition order: "first" wins.
|
||||
result := s.Solve("a", nil)
|
||||
if result.SetName != "first" {
|
||||
t.Errorf("SetName=%q want %q", result.SetName, "first")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatrixSolver_EvictCostsPreferred verifies that higher evict costs steer
|
||||
// the solver toward a cheaper set.
|
||||
func TestMatrixSolver_EvictCostsPreferred(t *testing.T) {
|
||||
// b is expensive to evict; c is cheap. Request "a" with both b and c
|
||||
// running. The solver should pick the set that keeps b.
|
||||
expanded := []config.ExpandedSet{
|
||||
{SetName: "a_with_c", DSL: "a & c", Models: []string{"a", "c"}}, // would evict b (cost 10)
|
||||
{SetName: "a_with_b", DSL: "a & b", Models: []string{"a", "b"}}, // would evict c (cost 1)
|
||||
}
|
||||
s := newMatrixSolver(expanded, map[string]int{"b": 10, "c": 1})
|
||||
|
||||
result := s.Solve("a", []string{"b", "c"})
|
||||
if result.SetName != "a_with_b" {
|
||||
t.Errorf("SetName=%q want %q (keep expensive b)", result.SetName, "a_with_b")
|
||||
}
|
||||
if len(result.Evict) != 1 || result.Evict[0] != "c" {
|
||||
t.Errorf("Evict=%v want [c]", result.Evict)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type peerMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
cfg config.Config
|
||||
logger *logmon.Monitor
|
||||
peers map[string]*peerMember
|
||||
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
inflight sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
|
||||
peers := cfg.Peers
|
||||
modelMap := make(map[string]*peerMember)
|
||||
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
logger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := modelMap[modelID]; found {
|
||||
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
modelMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
|
||||
return &Peer{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
peers: modelMap,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Peer) Handles(model string) bool {
|
||||
_, ok := r.peers[model]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *Peer) Shutdown(timeout time.Duration) error {
|
||||
if !r.shuttingDown.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("shutdown already in progress")
|
||||
}
|
||||
|
||||
if timeout == 0 {
|
||||
r.shutdownFn()
|
||||
r.inflight.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.inflight.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-time.After(timeout):
|
||||
r.shutdownFn()
|
||||
r.inflight.Wait()
|
||||
return fmt.Errorf("peer shutdown timed out after %v", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.shuttingDown.Load() {
|
||||
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||
return
|
||||
}
|
||||
r.inflight.Add(1)
|
||||
defer r.inflight.Done()
|
||||
|
||||
data, err := FetchContext(req, r.cfg)
|
||||
if err != nil {
|
||||
SendError(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
pp, found := r.peers[data.ModelID]
|
||||
if !found {
|
||||
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||
SendError(w, req, ErrNoPeerModelFound)
|
||||
return
|
||||
}
|
||||
|
||||
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
|
||||
|
||||
if pp.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
req.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
// Cancel the proxy request when the client disconnects or shutdown times out.
|
||||
// AfterFunc links both parent contexts to our child without a goroutine leak.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stopReq := context.AfterFunc(req.Context(), cancel)
|
||||
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
pp.reverseProxy.ServeHTTP(w, req)
|
||||
|
||||
stopShutdown()
|
||||
stopReq()
|
||||
cancel()
|
||||
}
|
||||
@@ -0,0 +1,611 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
var testLogger = logmon.NewWriter(os.Stdout)
|
||||
|
||||
func init() {
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
}
|
||||
|
||||
func TestNewPeer_EmptyPeers(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if pr == nil {
|
||||
t.Fatal("expected non-nil Peer")
|
||||
}
|
||||
if len(pr.peers) != 0 {
|
||||
t.Fatalf("expected empty peers map, got %d entries", len(pr.peers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(pr.peers) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", len(pr.peers))
|
||||
}
|
||||
if _, ok := pr.peers["model-a"]; !ok {
|
||||
t.Error("expected model-a to be mapped")
|
||||
}
|
||||
if _, ok := pr.peers["model-b"]; !ok {
|
||||
t.Error("expected model-b to be mapped")
|
||||
}
|
||||
if _, ok := pr.peers["model-c"]; ok {
|
||||
t.Error("expected model-c to not be mapped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(pr.peers) != 4 {
|
||||
t.Fatalf("expected 4 entries, got %d", len(pr.peers))
|
||||
}
|
||||
for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} {
|
||||
if _, ok := pr.peers[m]; !ok {
|
||||
t.Errorf("expected %s to be mapped", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_DuplicateModel(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(pr.peers) != 1 {
|
||||
t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers))
|
||||
}
|
||||
if _, ok := pr.peers["duplicate-model"]; !ok {
|
||||
t.Error("expected duplicate-model to be mapped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_Success(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if w.Body.String() != "response from peer" {
|
||||
t.Errorf("expected 'response from peer', got %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if receivedAuthHeader != "Bearer secret-api-key" {
|
||||
t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if receivedAuthHeader != "" {
|
||||
t.Errorf("expected no auth header, got %q", receivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if !strings.HasPrefix(receivedHost, "127.0.0.1:") {
|
||||
t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("X-Accel-Buffering") != "no" {
|
||||
t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = pr.Shutdown(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "shutting down") {
|
||||
t.Errorf("expected 'shutting down' in body, got %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
released := make(chan struct{})
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-released
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
pr.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
<-started
|
||||
|
||||
shutdownDone := make(chan error, 1)
|
||||
go func() {
|
||||
shutdownDone <- pr.Shutdown(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
// Shutdown should be waiting on inflight. If it finished already something is wrong.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
select {
|
||||
case err := <-shutdownDone:
|
||||
t.Errorf("shutdown completed before inflight finished: %v", err)
|
||||
default:
|
||||
}
|
||||
|
||||
close(released)
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case err := <-shutdownDone:
|
||||
if err != nil {
|
||||
t.Errorf("shutdown errored after inflight completed: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("shutdown did not complete after inflight finished")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
released := make(chan struct{})
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-released
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := httptest.NewRecorder()
|
||||
pr.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
<-started
|
||||
|
||||
err = pr.Shutdown(100 * time.Millisecond)
|
||||
if err == nil {
|
||||
t.Error("expected timeout error from shutdown")
|
||||
}
|
||||
|
||||
close(released)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPeer_ShutdownMultiple(t *testing.T) {
|
||||
pr, err := NewPeer(config.Config{}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = pr.Shutdown(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = pr.Shutdown(0)
|
||||
if err == nil {
|
||||
t.Error("expected error on second shutdown")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already in progress") {
|
||||
t.Errorf("expected 'already in progress', got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"extracted-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"context-model"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"body-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
pr.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPeer_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
member, ok := pr.peers["model1"]
|
||||
if !ok {
|
||||
t.Fatal("expected model1 to be mapped")
|
||||
}
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatal("expected Transport to be *http.Transport")
|
||||
}
|
||||
|
||||
if transport.ResponseHeaderTimeout != 300*time.Second {
|
||||
t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
}
|
||||
if transport.TLSHandshakeTimeout != 15*time.Second {
|
||||
t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
}
|
||||
if transport.ExpectContinueTimeout != 2*time.Second {
|
||||
t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout)
|
||||
}
|
||||
if transport.IdleConnTimeout != 120*time.Second {
|
||||
t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout)
|
||||
}
|
||||
if !transport.ForceAttemptHTTP2 {
|
||||
t.Error("expected ForceAttemptHTTP2 to be true")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type contextkey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type ReqContextData struct {
|
||||
Model string
|
||||
ModelID string
|
||||
Streaming bool
|
||||
SendLoadingState bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||
|
||||
ContextKey = &contextkey{"context"}
|
||||
)
|
||||
|
||||
type Router interface {
|
||||
// Shutdown blocks until the router has shutdown returning nil
|
||||
// when the router has shutdown successfully.
|
||||
//
|
||||
// timeout controls how long to wait for inflight requests to finish. After
|
||||
// the timeout all inflight requests will be cancelled.
|
||||
Shutdown(timeout time.Duration) error
|
||||
|
||||
// ServeHTTP implements the http.Handler and requests coming in will
|
||||
// trigger any model swapping and routing logic.
|
||||
ServeHTTP(http.ResponseWriter, *http.Request)
|
||||
|
||||
// Handles reports whether this router can serve requests for the given model.
|
||||
Handles(model string) bool
|
||||
}
|
||||
|
||||
// LocalRouter is a Router backed by local processes whose state can be
|
||||
// inspected and which can be individually stopped. Peer routers, which only
|
||||
// forward to remote hosts, do not implement it.
|
||||
type LocalRouter interface {
|
||||
Router
|
||||
|
||||
// RunningModels returns the current state of every process that is not
|
||||
// stopped or shut down, keyed by model ID.
|
||||
RunningModels() map[string]process.ProcessState
|
||||
|
||||
// Unload stops the named models, or every running model when none are
|
||||
// named. It blocks until each targeted process has stopped.
|
||||
Unload(timeout time.Duration, models ...string)
|
||||
|
||||
// ProcessLogger returns the log monitor for the named model's process.
|
||||
// modelID must be a real (non-alias) config key. Returns false when the
|
||||
// model is not known to this router.
|
||||
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context then
|
||||
// from the model body. If it extracts the model from the body it will
|
||||
// store the model in the context for downstream handlers. An error
|
||||
// will be returned when model can not be fetch from either location.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if data, err := ExtractContext(r); err == nil {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
realName = data.Model
|
||||
}
|
||||
data.ModelID = realName
|
||||
if mc, ok := cfg.Models[realName]; ok {
|
||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ContextKey, data)
|
||||
}
|
||||
|
||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// ExtractContext pulls the model name from an HTTP request without consuming the
|
||||
// body. For GET requests it reads the "model" query parameter. For POST
|
||||
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
|
||||
// application/x-www-form-urlencoded bodies. The request body is always restored
|
||||
// before returning so downstream handlers — including reverse proxies that
|
||||
// forward raw bytes upstream — can still read it.
|
||||
func ExtractContext(r *http.Request) (ReqContextData, error) {
|
||||
if r.Method == http.MethodGet {
|
||||
if model := r.URL.Query().Get("model"); model != "" {
|
||||
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
|
||||
}
|
||||
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}()
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
model := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if model == "" {
|
||||
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
|
||||
}
|
||||
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
|
||||
}
|
||||
|
||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||
// buffered bytes. The deferred restore above will reset r.Body again
|
||||
// after parsing.
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if model := r.FormValue("model"); model != "" {
|
||||
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
|
||||
}
|
||||
|
||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
switch {
|
||||
case errors.Is(err, ErrNoModelInContext):
|
||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||
case errors.Is(err, ErrNoPeerModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||
case errors.Is(err, ErrNoLocalModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
case errors.Is(err, ErrNoRouterFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||
default:
|
||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
// Check Accept header for preferred response format
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/plain") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(acceptHeader, "text/html") {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractContext_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "model=llama3", "llama3", false},
|
||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||
{"model missing", "", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||
{"model empty string", `{"model":""}`, "", true},
|
||||
{"model key missing", `{"stream":true}`, "", true},
|
||||
{"invalid json", `not-json`, "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
form := url.Values{}
|
||||
if tt.formModel != "" {
|
||||
form.Set("model", tt.formModel)
|
||||
}
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
if tt.formModel != "" {
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte(tt.formModel))
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
got, err := ExtractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||
body := `{"model":"llama3","stream":true}`
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte("whisper-1"))
|
||||
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
||||
ff.Write([]byte("fake-audio-bytes"))
|
||||
mw.Close()
|
||||
|
||||
original := buf.Bytes()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if !bytes.Equal(remaining, original) {
|
||||
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||
body := "model=whisper-1&extra=value"
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
if _, err := ExtractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
||||
if !ok {
|
||||
t.Fatalf("ContextKey not set or wrong type")
|
||||
}
|
||||
if data.Model != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_WithAlias(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||
data, _ := ctx.Value(ContextKey).(ReqContextData)
|
||||
if data.Model != "llama" {
|
||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||
parent := context.Background()
|
||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
if v := parent.Value(ContextKey); v != nil {
|
||||
t.Errorf("parent context was mutated: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
wantReq string
|
||||
wantReal string
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "model present, same name",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
||||
wantReq: "llama3",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model present, aliased",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
||||
wantReq: "llama",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model absent",
|
||||
ctx: context.Background(),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "model is empty string",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotData, ok := ReadContext(tt.ctx)
|
||||
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
||||
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
||||
type modelRecord struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||
// (with optional aliases) plus peer models.
|
||||
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
created := time.Now().Unix()
|
||||
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||
|
||||
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
|
||||
rec := modelRecord{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
Created: created,
|
||||
OwnedBy: "llama-swap",
|
||||
Name: strings.TrimSpace(name),
|
||||
Description: strings.TrimSpace(description),
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||
}
|
||||
return rec
|
||||
}
|
||||
|
||||
for id, mc := range s.cfg.Models {
|
||||
if mc.Unlisted {
|
||||
continue
|
||||
}
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
|
||||
|
||||
if s.cfg.IncludeAliasesInList {
|
||||
for _, alias := range mc.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID })
|
||||
|
||||
// Echo the Origin so browser clients can read the listing.
|
||||
if origin := r.Header.Get("Origin"); origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// runningModel is one entry in the /running listing.
|
||||
type runningModel struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// handleUnload stops every running local process. Peer models are remote and
|
||||
// unaffected.
|
||||
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
// handleRunning lists local processes that are not stopped, joining each model
|
||||
// ID against its config for the cmd/proxy/ttl/name/description metadata.
|
||||
func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) {
|
||||
states := s.local.RunningModels()
|
||||
list := make([]runningModel, 0, len(states))
|
||||
for id, state := range states {
|
||||
mc := s.cfg.Models[id]
|
||||
list = append(list, runningModel{
|
||||
Model: id,
|
||||
State: string(state),
|
||||
Cmd: mc.Cmd,
|
||||
Proxy: mc.Proxy,
|
||||
TTL: mc.UnloadAfter,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
})
|
||||
}
|
||||
sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model })
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{"running": list})
|
||||
}
|
||||
|
||||
// discardResponseWriter satisfies http.ResponseWriter for preload requests,
|
||||
// dropping the body while capturing the status code.
|
||||
type discardResponseWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (d *discardResponseWriter) Header() http.Header {
|
||||
if d.header == nil {
|
||||
d.header = make(http.Header)
|
||||
}
|
||||
return d.header
|
||||
}
|
||||
|
||||
func (d *discardResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
func (d *discardResponseWriter) WriteHeader(status int) { d.status = status }
|
||||
|
||||
// startPreload fires a background GET / at every model named in
|
||||
// Hooks.OnStartup.Preload so they are warm before the first real request.
|
||||
// Preload names are already resolved to real model IDs by config loading.
|
||||
func (s *Server) startPreload() {
|
||||
models := s.cfg.Hooks.OnStartup.Preload
|
||||
if len(models) == 0 {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for _, modelID := range models {
|
||||
if !s.local.Handles(modelID) {
|
||||
s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID)
|
||||
continue
|
||||
}
|
||||
s.proxylog.Infof("preloading model: %s", modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||
|
||||
dw := &discardResponseWriter{status: http.StatusOK}
|
||||
s.local.ServeHTTP(dw, req)
|
||||
|
||||
success := dw.status < http.StatusBadRequest
|
||||
if !success {
|
||||
s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status)
|
||||
}
|
||||
event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleMetrics serves Prometheus-format performance metrics. Returns 503 when
|
||||
// performance monitoring is disabled.
|
||||
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte("# performance monitor not available\n"))
|
||||
return
|
||||
}
|
||||
s.perf.MetricsHandler().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func handleRootRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui", http.StatusFound)
|
||||
}
|
||||
|
||||
func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui/models", http.StatusFound)
|
||||
}
|
||||
|
||||
// handleUpstream proxies ANY request under /upstream/<model>/<path> directly to
|
||||
// the model's process, bypassing model dispatch by body/query inspection.
|
||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamPath := r.PathValue("upstreamPath")
|
||||
|
||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect /upstream/model to /upstream/model/ so relative URLs in upstream
|
||||
// responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method.
|
||||
if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") {
|
||||
newPath := "/upstream/" + searchName + "/"
|
||||
if r.URL.RawQuery != "" {
|
||||
newPath += "?" + r.URL.RawQuery
|
||||
}
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
||||
http.Redirect(w, r, newPath, http.StatusMovedPermanently)
|
||||
} else {
|
||||
http.Redirect(w, r, newPath, http.StatusPermanentRedirect)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Strip the /upstream/<model> prefix before forwarding.
|
||||
r.URL.Path = remainingPath
|
||||
// Pin the resolved model so the router skips body/query extraction.
|
||||
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||
|
||||
switch {
|
||||
case s.local.Handles(modelID):
|
||||
s.local.ServeHTTP(w, r)
|
||||
case s.peer.Handles(modelID):
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// findModelInPath walks a slash-separated path, building up segments until one
|
||||
// matches a configured model. This resolves model names that contain slashes
|
||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||
// remaining path, and whether a match was found.
|
||||
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
name := ""
|
||||
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = part
|
||||
} else {
|
||||
name = name + "/" + part
|
||||
}
|
||||
|
||||
if modelID, ok := cfg.RealModelName(name); ok {
|
||||
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", "", false
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestServer_HandleListModels(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"visible": {Name: "Visible", Description: "a model"},
|
||||
"hidden": {Unlisted: true},
|
||||
},
|
||||
Peers: config.PeerDictionaryConfig{
|
||||
"peer1": {Models: []string{"remote-model"}},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q", got)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Data []modelRecord `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
ids := map[string]bool{}
|
||||
for _, m := range resp.Data {
|
||||
ids[m.ID] = true
|
||||
}
|
||||
if !ids["visible"] || !ids["remote-model"] {
|
||||
t.Errorf("missing expected models: %v", ids)
|
||||
}
|
||||
if ids["hidden"] {
|
||||
t.Error("unlisted model should not appear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleListModels_Aliases(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
IncludeAliasesInList: true,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"real": {Aliases: []string{"nick"}},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
|
||||
|
||||
var resp struct {
|
||||
Data []modelRecord `json:"data"`
|
||||
}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
ids := map[string]bool{}
|
||||
for _, m := range resp.Data {
|
||||
ids[m.ID] = true
|
||||
}
|
||||
if !ids["real"] || !ids["nick"] {
|
||||
t.Errorf("expected alias entry; got %v", ids)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_FindModelInPath(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||
"author/model": {},
|
||||
"simple": {},
|
||||
}}
|
||||
|
||||
cases := []struct {
|
||||
path string
|
||||
wantName string
|
||||
wantRem string
|
||||
wantFound bool
|
||||
}{
|
||||
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||
{"/author/model", "author/model", "/", true},
|
||||
{"/missing/v1", "", "", false},
|
||||
{"/", "", "", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
name, _, rem, found := findModelInPath(cfg, c.path)
|
||||
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
|
||||
t.Run("proxies to local", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat", nil))
|
||||
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||
t.Errorf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redirects bare model path", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1", nil))
|
||||
if w.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("status = %d, want 301", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown model 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/nope/v1", nil))
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/metrics", nil))
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Redirects(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
for path, want := range map[string]string{"/": "/ui", "/upstream": "/ui/models"} {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("%s: status = %d, want 302", path, w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Location"); got != want {
|
||||
t.Errorf("%s: Location = %q, want %q", path, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiModel is one entry in the /api/events modelStatus payload.
|
||||
type apiModel struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
// modelStatus returns every configured model joined with its current process
|
||||
// state (defaulting to "stopped"), followed by peer models.
|
||||
func (s *Server) modelStatus() []apiModel {
|
||||
running := s.local.RunningModels()
|
||||
|
||||
ids := make([]string, 0, len(s.cfg.Models))
|
||||
for id := range s.cfg.Models {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
sort.Strings(ids)
|
||||
|
||||
models := make([]apiModel, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
mc := s.cfg.Models[id]
|
||||
state := "stopped"
|
||||
if st, ok := running[id]; ok {
|
||||
state = string(st)
|
||||
}
|
||||
models = append(models, apiModel{
|
||||
Id: id,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
State: state,
|
||||
Unlisted: mc.Unlisted,
|
||||
Aliases: mc.Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, apiModel{Id: modelID, PeerID: peerID})
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// handleAPIUnloadAll stops every running local process.
|
||||
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
|
||||
}
|
||||
|
||||
// handleAPIUnloadModel stops a single named local process.
|
||||
func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||
realName, found := s.cfg.RealModelName(requested)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
if !s.local.Handles(realName) {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
return
|
||||
}
|
||||
s.local.Unload(0, realName)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
// handleAPIMetrics serves the activity log as a JSON array.
|
||||
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := s.metrics.getMetricsJSON()
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// handleAPIPerformance serves the buffered system/GPU stats, optionally
|
||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||
return
|
||||
}
|
||||
|
||||
sysStats, gpuStats := s.perf.Current()
|
||||
|
||||
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||
after, err := time.Parse(time.RFC3339, afterStr)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||
return
|
||||
}
|
||||
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||
for _, st := range sysStats {
|
||||
if st.Timestamp.After(after) {
|
||||
filteredSys = append(filteredSys, st)
|
||||
}
|
||||
}
|
||||
sysStats = filteredSys
|
||||
|
||||
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
|
||||
for _, g := range gpuStats {
|
||||
if g.Timestamp.After(after) {
|
||||
filteredGpu = append(filteredGpu, g)
|
||||
}
|
||||
}
|
||||
gpuStats = filteredGpu
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"sys_stats": sysStats,
|
||||
"gpu_stats": gpuStats,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAPIVersion serves the build metadata.
|
||||
func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"version": s.build.Version,
|
||||
"commit": s.build.Commit,
|
||||
"build_date": s.build.Date,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAPICapture returns the stored request/response capture for a metric ID.
|
||||
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||
return
|
||||
}
|
||||
|
||||
capture := s.metrics.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
}
|
||||
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// handleAPIEvents streams server events (model status, log data, metrics,
|
||||
// in-flight counts) to the client as Server-Sent Events.
|
||||
func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering SSE
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
// internal/event already has a 50K event buffer
|
||||
// a 1K message buffer should be enough, watch the logs for the warning that the sendBuffer is full
|
||||
sendBuffer := make(chan messageEnvelope, 1024)
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
send := func(msg messageEnvelope) {
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
s.proxylog.Warn("handleAPIEvents send suppressed due to context done")
|
||||
default:
|
||||
s.proxylog.Warn("handleAPIEvents sendBuffer full, dropped message")
|
||||
}
|
||||
}
|
||||
sendModels := func() {
|
||||
if data, err := json.Marshal(s.modelStatus()); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeModelStatus, Data: string(data)})
|
||||
}
|
||||
}
|
||||
sendLogData := func(source string, data []byte) {
|
||||
if j, err := json.Marshal(map[string]string{"source": source, "data": string(data)}); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeLogData, Data: string(j)})
|
||||
}
|
||||
}
|
||||
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||
if j, err := json.Marshal(metrics); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeMetrics, Data: string(j)})
|
||||
}
|
||||
}
|
||||
sendInFlight := func(total int) {
|
||||
if j, err := json.Marshal(map[string]int{"total": total}); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeInFlight, Data: string(j)})
|
||||
}
|
||||
}
|
||||
|
||||
defer event.On(func(e shared.ProcessStateChangeEvent) { sendModels() })()
|
||||
defer event.On(func(e shared.ConfigFileChangedEvent) { sendModels() })()
|
||||
defer s.proxylog.OnLogData(func(data []byte) { sendLogData("proxy", data) })()
|
||||
defer s.upstreamlog.OnLogData(func(data []byte) { sendLogData("upstream", data) })()
|
||||
defer event.On(func(e ActivityLogEvent) { sendMetrics([]ActivityLogEntry{e.Metrics}) })()
|
||||
defer event.On(func(e shared.InFlightRequestsEvent) { sendInFlight(e.Total) })()
|
||||
|
||||
// initial payload
|
||||
sendLogData("proxy", s.proxylog.GetHistory())
|
||||
sendLogData("upstream", s.upstreamlog.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(s.metrics.getMetrics())
|
||||
sendInFlight(int(s.inflight.Current()))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-s.shutdownCtx.Done():
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(w, "event:message\ndata:%s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServer_InflightMiddleware(t *testing.T) {
|
||||
c := &inflightCounter{}
|
||||
mw := CreateInflightMiddleware(c)
|
||||
|
||||
var duringRequest int64
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
duringRequest = c.Current()
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))
|
||||
|
||||
if duringRequest != 1 {
|
||||
t.Errorf("counter during request = %d, want 1", duringRequest)
|
||||
}
|
||||
if got := c.Current(); got != 0 {
|
||||
t.Errorf("counter after request = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIVersion(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.build = BuildInfo{Version: "1.2.3", Commit: "deadbeef", Date: "2026-05-19"}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/version", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got["version"] != "1.2.3" || got["commit"] != "deadbeef" || got["build_date"] != "2026-05-19" {
|
||||
t.Errorf("body = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIMetrics_Empty(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/metrics", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if body := strings.TrimSpace(w.Body.String()); body != "[]" {
|
||||
t.Errorf("body = %q, want []", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIPerformance_Unavailable(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/performance", nil))
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIEvents_InitialPayload(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/events", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.ServeHTTP(w, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handler did not return after context cancel")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
for _, want := range []string{`"type":"modelStatus"`, `"type":"inflight"`, `"type":"logData"`} {
|
||||
if !strings.Contains(body, want) {
|
||||
t.Errorf("initial SSE payload missing %s; body=%q", want, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||
// config declares any. It accepts the key via Authorization: Bearer,
|
||||
// Authorization: Basic (password field), or x-api-key. On success the auth
|
||||
// headers are stripped so they never leak to upstream. When no keys are
|
||||
// configured the middleware is a pass-through.
|
||||
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
keys := cfg.RequiredAPIKeys
|
||||
return func(next http.Handler) http.Handler {
|
||||
if len(keys) == 0 {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provided := extractAPIKey(r)
|
||||
|
||||
valid := false
|
||||
for _, key := range keys {
|
||||
if provided == key {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Del("Authorization")
|
||||
r.Header.Del("x-api-key")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func extractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCORSMiddleware returns middleware that answers OPTIONS preflight
|
||||
// requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS
|
||||
// requests pass through untouched.
|
||||
func CreateCORSMiddleware() chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodOptions {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers))
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With")
|
||||
}
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sanitizeAccessControlRequestHeaderValues drops any header names that contain
|
||||
// characters outside the HTTP token grammar before echoing them back.
|
||||
func sanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := extractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"Content-Type, Authorization", "Content-Type, Authorization"},
|
||||
{" X-Custom , Accept ", "X-Custom, Accept"},
|
||||
{"Valid, Bad Header", "Valid"},
|
||||
{"Bad@Header", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want {
|
||||
t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_IsTokenChar(t *testing.T) {
|
||||
for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" {
|
||||
if !isTokenChar(r) {
|
||||
t.Errorf("isTokenChar(%q) = false, want true", r)
|
||||
}
|
||||
}
|
||||
for _, r := range " @()/\t\"" {
|
||||
if isTokenChar(r) {
|
||||
t.Errorf("isTokenChar(%q) = true, want false", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_AuthMiddleware(t *testing.T) {
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
||||
t.Error("auth headers leaked to upstream")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("no keys configured passes through", func(t *testing.T) {
|
||||
mw := CreateAuthMiddleware(config.Config{})
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
cfg := config.Config{RequiredAPIKeys: []string{"secret"}}
|
||||
|
||||
t.Run("valid key", func(t *testing.T) {
|
||||
mw := CreateAuthMiddleware(cfg)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Authorization", "Bearer secret")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid key", func(t *testing.T) {
|
||||
mw := CreateAuthMiddleware(cfg)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Authorization", "Bearer wrong")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
if w.Header().Get("WWW-Authenticate") == "" {
|
||||
t.Error("missing WWW-Authenticate header")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// ReqRespCapture is a stored request/response pair for a single metered request.
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// captureFields is a bitmask controlling what a route stores in a ReqRespCapture.
|
||||
type captureFields uint
|
||||
|
||||
const (
|
||||
captureReqHeaders captureFields = 1 << iota
|
||||
captureReqBody
|
||||
captureRespHeaders
|
||||
captureRespBody
|
||||
)
|
||||
|
||||
const (
|
||||
captureReqAll = captureReqHeaders | captureReqBody
|
||||
captureRespAll = captureRespHeaders | captureRespBody
|
||||
captureAll = captureReqAll | captureRespAll
|
||||
)
|
||||
|
||||
// captureFieldsByPath overrides the default capture mask for routes carrying
|
||||
// large binary payloads (audio/image) where storing the full body is wasteful.
|
||||
var captureFieldsByPath = map[string]captureFields{
|
||||
"/v1/audio/speech": captureReqAll | captureRespHeaders,
|
||||
"/v1/audio/voices": captureReqHeaders | captureRespAll,
|
||||
"/v1/audio/transcriptions": captureReqHeaders | captureRespHeaders | captureRespBody,
|
||||
"/v1/images/generations": captureReqAll | captureRespHeaders,
|
||||
"/v1/images/edits": captureReqHeaders | captureRespHeaders,
|
||||
"/sdapi/v1/txt2img": captureReqAll | captureRespHeaders,
|
||||
"/sdapi/v1/img2img": captureReqHeaders | captureRespHeaders,
|
||||
}
|
||||
|
||||
// captureFieldsFor returns the capture mask for a request path. Unlisted routes
|
||||
// (the OpenAI-compatible JSON endpoints) capture everything.
|
||||
func captureFieldsFor(path string) captureFields {
|
||||
if cf, ok := captureFieldsByPath[path]; ok {
|
||||
return cf
|
||||
}
|
||||
return captureAll
|
||||
}
|
||||
|
||||
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||
var zstdEncOptions = []zstd.EOption{
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
}
|
||||
|
||||
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||
var zstdEncPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||
return enc
|
||||
},
|
||||
}
|
||||
|
||||
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||
var zstdDecPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
dec, _ := zstd.NewReader(nil)
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||
// Returns the compressed bytes and the original CBOR byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
cborBytes, err := cbor.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(zenc)
|
||||
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed CBOR into a ReqRespCapture.
|
||||
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
cborBytes, err := dec.DecodeAll(data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||
}
|
||||
var capture ReqRespCapture
|
||||
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||
}
|
||||
return &capture, nil
|
||||
}
|
||||
|
||||
// addCapture compresses and stores a capture in the cache. Returns true if the
|
||||
// capture was stored.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||
return false
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
return true
|
||||
}
|
||||
|
||||
// getCaptureByID decompresses and unmarshals a capture by ID. Returns nil if
|
||||
// the capture is not found or decompression fails.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
if mp.captureCache == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := mp.captureCache.Get(id)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
capture, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
return capture
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that are redacted in captures.
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// headerMap flattens an http.Header to a single-value map.
|
||||
func headerMap(h http.Header) map[string]string {
|
||||
m := make(map[string]string, len(h))
|
||||
for key, values := range h {
|
||||
if len(values) > 0 {
|
||||
m[key] = values[0]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]".
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestServer_CaptureCompressRoundtrip(t *testing.T) {
|
||||
orig := &ReqRespCapture{
|
||||
ID: 7,
|
||||
ReqPath: "/v1/chat/completions",
|
||||
ReqHeaders: map[string]string{"Content-Type": "application/json"},
|
||||
ReqBody: []byte(`{"model":"m"}`),
|
||||
RespHeaders: map[string]string{"Content-Type": "application/json"},
|
||||
RespBody: []byte(`{"usage":{}}`),
|
||||
}
|
||||
|
||||
compressed, uncompressed, err := compressCapture(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("compressCapture: %v", err)
|
||||
}
|
||||
if uncompressed == 0 || len(compressed) == 0 {
|
||||
t.Fatalf("unexpected sizes: uncompressed=%d compressed=%d", uncompressed, len(compressed))
|
||||
}
|
||||
|
||||
got, err := decompressCapture(compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("decompressCapture: %v", err)
|
||||
}
|
||||
if got.ID != orig.ID || got.ReqPath != orig.ReqPath ||
|
||||
!bytes.Equal(got.ReqBody, orig.ReqBody) || !bytes.Equal(got.RespBody, orig.RespBody) {
|
||||
t.Fatalf("roundtrip mismatch: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CaptureStoreAndRetrieve(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||
if !mm.enableCaptures {
|
||||
t.Fatal("captures should be enabled with non-zero buffer")
|
||||
}
|
||||
|
||||
capture := ReqRespCapture{ID: 3, ReqPath: "/v1/chat/completions", ReqBody: []byte("hello")}
|
||||
if !mm.addCapture(capture) {
|
||||
t.Fatal("addCapture returned false")
|
||||
}
|
||||
|
||||
got := mm.getCaptureByID(3)
|
||||
if got == nil || !bytes.Equal(got.ReqBody, []byte("hello")) {
|
||||
t.Fatalf("getCaptureByID = %+v", got)
|
||||
}
|
||||
if mm.getCaptureByID(999) != nil {
|
||||
t.Fatal("expected nil for unknown capture ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CaptureDisabled(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 0)
|
||||
if mm.enableCaptures {
|
||||
t.Fatal("captures should be disabled with zero buffer")
|
||||
}
|
||||
if mm.addCapture(ReqRespCapture{ID: 1}) {
|
||||
t.Fatal("addCapture should return false when disabled")
|
||||
}
|
||||
if mm.getCaptureByID(1) != nil {
|
||||
t.Fatal("getCaptureByID should return nil when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CaptureFieldsFor(t *testing.T) {
|
||||
if got := captureFieldsFor("/v1/chat/completions"); got != captureAll {
|
||||
t.Fatalf("default = %b, want captureAll", got)
|
||||
}
|
||||
if got := captureFieldsFor("/v1/audio/speech"); got != captureReqAll|captureRespHeaders {
|
||||
t.Fatalf("/v1/audio/speech = %b", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||
// the model config leaves concurrencyLimit unset. Matches the legacy
|
||||
// proxy.Process default.
|
||||
const defaultConcurrencyLimit = 10
|
||||
|
||||
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
|
||||
// model-dispatched requests per model. Each model gets a semaphore sized to
|
||||
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
|
||||
// immediately acquire a slot is rejected with 429. Models without a local
|
||||
// config entry (e.g. peer-routed models) are not limited.
|
||||
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
||||
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
|
||||
for id, mc := range cfg.Models {
|
||||
limit := defaultConcurrencyLimit
|
||||
if mc.ConcurrencyLimit > 0 {
|
||||
limit = mc.ConcurrencyLimit
|
||||
}
|
||||
semaphores[id] = semaphore.NewWeighted(int64(limit))
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
// fall through for peer models
|
||||
sem, ok := semaphores[data.ModelID]
|
||||
if !ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !sem.TryAcquire(1) {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer sem.Release(1)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
func concurrencyTestReq(model string) *http.Request {
|
||||
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
||||
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"m1": {ConcurrencyLimit: 1},
|
||||
},
|
||||
}
|
||||
|
||||
entered := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var once sync.Once
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
once.Do(func() { close(entered) })
|
||||
<-release
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||
|
||||
// First request occupies the only slot.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
|
||||
}()
|
||||
<-entered
|
||||
|
||||
// Second concurrent request is rejected with 429.
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("over-limit status = %d, want 429", w.Code)
|
||||
}
|
||||
|
||||
// Once the slot frees, a new request succeeds.
|
||||
close(release)
|
||||
<-done
|
||||
w = httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("post-release status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{}}
|
||||
|
||||
called := 0
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
|
||||
if w.Code != http.StatusOK || called != 1 {
|
||||
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestServer_DecompressBody(t *testing.T) {
|
||||
plain := []byte("hello world")
|
||||
|
||||
var gz bytes.Buffer
|
||||
gw := gzip.NewWriter(&gz)
|
||||
gw.Write(plain)
|
||||
gw.Close()
|
||||
|
||||
var fl bytes.Buffer
|
||||
fw, _ := flate.NewWriter(&fl, flate.DefaultCompression)
|
||||
fw.Write(plain)
|
||||
fw.Close()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
body []byte
|
||||
encoding string
|
||||
}{
|
||||
{"plain", plain, ""},
|
||||
{"gzip", gz.Bytes(), "gzip"},
|
||||
{"deflate", fl.Bytes(), "deflate"},
|
||||
{"unknown passthrough", plain, "br"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got, err := decompressBody(c.body, c.encoding)
|
||||
if err != nil {
|
||||
t.Fatalf("decompressBody: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plain) {
|
||||
t.Errorf("got %q, want %q", got, plain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_FilterAcceptEncoding(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", ""},
|
||||
{"gzip, deflate, br", "gzip, deflate"},
|
||||
{"br, zstd", ""},
|
||||
{"gzip;q=1.0", "gzip;q=1.0"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := filterAcceptEncoding(c.in); got != c.want {
|
||||
t.Errorf("filterAcceptEncoding(%q) = %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_BodyCopier_Flush(t *testing.T) {
|
||||
bc := newBodyCopier(httptest.NewRecorder())
|
||||
bc.Write([]byte("data"))
|
||||
bc.Flush()
|
||||
if bc.Status() != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", bc.Status())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HeaderMapAndRedact(t *testing.T) {
|
||||
h := http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
"Authorization": {"Bearer secret"},
|
||||
"X-Api-Key": {"key123"},
|
||||
}
|
||||
m := headerMap(h)
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q", m["Content-Type"])
|
||||
}
|
||||
|
||||
redactHeaders(m)
|
||||
if m["Authorization"] != "[REDACTED]" || m["X-Api-Key"] != "[REDACTED]" {
|
||||
t.Errorf("sensitive headers not redacted: %v", m)
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Error("non-sensitive header should not be redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StripVersionPrefix(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/v/v1/chat", nil)
|
||||
stripVersionPrefix(r)
|
||||
if r.URL.Path != "/v1/chat" {
|
||||
t.Errorf("path = %q, want /v1/chat", r.URL.Path)
|
||||
}
|
||||
|
||||
r2 := httptest.NewRequest(http.MethodGet, "/v1/chat", nil)
|
||||
stripVersionPrefix(r2)
|
||||
if r2.URL.Path != "/v1/chat" {
|
||||
t.Errorf("path = %q, want unchanged", r2.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CloseStreams(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.CloseStreams()
|
||||
select {
|
||||
case <-s.shutdownCtx.Done():
|
||||
default:
|
||||
t.Error("CloseStreams did not cancel shutdown context")
|
||||
}
|
||||
s.CloseStreams() // idempotent
|
||||
}
|
||||
|
||||
func TestServer_HandleUIAndFavicon(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
for _, path := range []string{"/ui/", "/favicon.ico"} {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
// The embedded ui_dist only carries placeholder.txt in test builds, so
|
||||
// these resolve to 404 — the handlers still execute end to end.
|
||||
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
|
||||
t.Errorf("%s: status = %d", path, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleAPIUnloadAll(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if local.unloadCalls.Load() != 1 {
|
||||
t.Errorf("unloadCalls = %d, want 1", local.unloadCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleAPIUnloadModel(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
|
||||
t.Run("known model", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/m1", nil))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown model 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/nope", nil))
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleAPICapture(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.metrics = newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||
s.metrics.addCapture(ReqRespCapture{ID: 42, ReqPath: "/v1/chat/completions"})
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/42", nil))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if !bytes.Contains(w.Body.Bytes(), []byte("/v1/chat/completions")) {
|
||||
t.Errorf("body = %q", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/999", nil))
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid id", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/abc", nil))
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want 400", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// CreateFilterMiddleware returns middleware that applies per-model request-body
|
||||
// filters to JSON requests before they are forwarded upstream:
|
||||
//
|
||||
// - UseModelName rewrite (issue #69)
|
||||
// - StripParams removal (issue #174)
|
||||
// - SetParams injection (issue #453)
|
||||
// - SetParamsByID per-alias overrides
|
||||
//
|
||||
// Non-JSON requests (GET, multipart forms) pass through untouched. The buffered
|
||||
// body is re-attached with Content-Length / Transfer-Encoding cleanup so the
|
||||
// downstream reverse proxy forwards the correct bytes (see issue #11).
|
||||
func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), "application/json") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
useModelName, filters, ok := resolveFilters(cfg, data.Model)
|
||||
if !ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||
return
|
||||
}
|
||||
|
||||
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
r.Header.Del("Transfer-Encoding")
|
||||
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
r.ContentLength = int64(len(body))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CreateFormFilterMiddleware returns middleware that applies the UseModelName
|
||||
// rewrite (issue #69) to multipart/form-data requests before they are forwarded
|
||||
// upstream. JSON-body filters (StripParams, SetParams) do not apply to form
|
||||
// endpoints; only the "model" field is rewritten.
|
||||
//
|
||||
// Non-multipart requests pass through untouched. When a rewrite is needed the
|
||||
// form is reconstructed and re-attached with Content-Type / Content-Length
|
||||
// cleanup so the downstream reverse proxy forwards the correct bytes.
|
||||
func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
useModelName, _, ok := resolveFilters(cfg, data.Model)
|
||||
if !ok || useModelName == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
r.MultipartForm = nil
|
||||
r.Header.Del("Transfer-Encoding")
|
||||
r.Header.Set("Content-Type", contentType)
|
||||
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
r.ContentLength = int64(len(body))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteMultipartModel reconstructs a multipart form, replacing the "model"
|
||||
// field value with useModelName. It returns the encoded body and the matching
|
||||
// Content-Type header (which carries the generated boundary).
|
||||
func rewriteMultipartModel(form *multipart.Form, useModelName string) ([]byte, string, error) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
|
||||
for key, values := range form.Value {
|
||||
for _, value := range values {
|
||||
if key == "model" {
|
||||
value = useModelName
|
||||
}
|
||||
field, err := mw.CreateFormField(key)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error recreating form field %s: %w", key, err)
|
||||
}
|
||||
if _, err := field.Write([]byte(value)); err != nil {
|
||||
return nil, "", fmt.Errorf("error writing form field %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for key, headers := range form.File {
|
||||
for _, fh := range headers {
|
||||
part, err := mw.CreateFormFile(key, fh.Filename)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error recreating form file %s: %w", key, err)
|
||||
}
|
||||
file, err := fh.Open()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error opening uploaded file %s: %w", key, err)
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
file.Close()
|
||||
return nil, "", fmt.Errorf("error copying file data %s: %w", key, err)
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, "", fmt.Errorf("error finalizing multipart form: %w", err)
|
||||
}
|
||||
return buf.Bytes(), mw.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
// resolveFilters returns the filter settings for a requested model. UseModelName
|
||||
// only applies to local models; peers carry filters but no name rewrite.
|
||||
func resolveFilters(cfg config.Config, requested string) (useModelName string, filters config.Filters, ok bool) {
|
||||
if realName, found := cfg.RealModelName(requested); found {
|
||||
mc := cfg.Models[realName]
|
||||
return mc.UseModelName, mc.Filters.Filters, true
|
||||
}
|
||||
for _, peer := range cfg.Peers {
|
||||
for _, m := range peer.Models {
|
||||
if m == requested {
|
||||
return "", peer.Filters, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", config.Filters{}, false
|
||||
}
|
||||
|
||||
// applyFilters rewrites the JSON body in place. Order matches the legacy
|
||||
// ProxyManager: useModelName, stripParams, setParams, then setParamsByID (which
|
||||
// can override setParams).
|
||||
func applyFilters(body []byte, requested, useModelName string, f config.Filters) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
if useModelName != "" {
|
||||
if body, err = sjson.SetBytes(body, "model", useModelName); err != nil {
|
||||
return nil, fmt.Errorf("error rewriting model name in JSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, param := range f.SanitizedStripParams() {
|
||||
if body, err = sjson.DeleteBytes(body, param); err != nil {
|
||||
return nil, fmt.Errorf("error stripping parameter %s from request", param)
|
||||
}
|
||||
}
|
||||
|
||||
setParams, setKeys := f.SanitizedSetParams()
|
||||
for _, key := range setKeys {
|
||||
if body, err = sjson.SetBytes(body, key, setParams[key]); err != nil {
|
||||
return nil, fmt.Errorf("error setting parameter %s in request", key)
|
||||
}
|
||||
}
|
||||
|
||||
byID, byIDKeys := f.SanitizedSetParamsByID(requested)
|
||||
for _, key := range byIDKeys {
|
||||
if body, err = sjson.SetBytes(body, key, byID[key]); err != nil {
|
||||
return nil, fmt.Errorf("error setting parameter %s in request", key)
|
||||
}
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestServer_ApplyFilters(t *testing.T) {
|
||||
t.Run("useModelName rewrite", func(t *testing.T) {
|
||||
out, err := applyFilters([]byte(`{"model":"alias","temp":1}`), "alias", "real-model", config.Filters{})
|
||||
if err != nil {
|
||||
t.Fatalf("applyFilters: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "model").String(); got != "real-model" {
|
||||
t.Errorf("model = %q, want real-model", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strip and set params", func(t *testing.T) {
|
||||
f := config.Filters{
|
||||
StripParams: "temperature",
|
||||
SetParams: map[string]any{"top_p": 0.9},
|
||||
}
|
||||
out, err := applyFilters([]byte(`{"model":"m","temperature":0.7}`), "m", "", f)
|
||||
if err != nil {
|
||||
t.Fatalf("applyFilters: %v", err)
|
||||
}
|
||||
if gjson.GetBytes(out, "temperature").Exists() {
|
||||
t.Error("temperature should be stripped")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.9 {
|
||||
t.Errorf("top_p = %v, want 0.9", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("setParamsByID overrides setParams", func(t *testing.T) {
|
||||
f := config.Filters{
|
||||
SetParams: map[string]any{"top_p": 0.5},
|
||||
SetParamsByID: map[string]map[string]any{"alias": {"top_p": 0.1}},
|
||||
}
|
||||
out, err := applyFilters([]byte(`{"model":"alias"}`), "alias", "", f)
|
||||
if err != nil {
|
||||
t.Fatalf("applyFilters: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.1 {
|
||||
t.Errorf("top_p = %v, want 0.1", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_RewriteMultipartModel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
mw.WriteField("model", "old-name")
|
||||
mw.WriteField("language", "en")
|
||||
fw, _ := mw.CreateFormFile("file", "audio.wav")
|
||||
fw.Write([]byte("RIFFdata"))
|
||||
mw.Close()
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
t.Fatalf("ParseMultipartForm: %v", err)
|
||||
}
|
||||
|
||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, "new-name")
|
||||
if err != nil {
|
||||
t.Fatalf("rewriteMultipartModel: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := multipart.NewReader(bytes.NewReader(body), boundaryOf(t, contentType)).ReadForm(32 << 20)
|
||||
if err != nil {
|
||||
t.Fatalf("re-parse: %v", err)
|
||||
}
|
||||
if got := parsed.Value["model"][0]; got != "new-name" {
|
||||
t.Errorf("model = %q, want new-name", got)
|
||||
}
|
||||
if got := parsed.Value["language"][0]; got != "en" {
|
||||
t.Errorf("language = %q, want en", got)
|
||||
}
|
||||
fh := parsed.File["file"][0]
|
||||
f, _ := fh.Open()
|
||||
data, _ := io.ReadAll(f)
|
||||
f.Close()
|
||||
if string(data) != "RIFFdata" {
|
||||
t.Errorf("file data = %q, want RIFFdata", data)
|
||||
}
|
||||
}
|
||||
|
||||
func boundaryOf(t *testing.T, contentType string) string {
|
||||
t.Helper()
|
||||
_, params, ok := strings.Cut(contentType, "boundary=")
|
||||
if !ok {
|
||||
t.Fatalf("no boundary in %q", contentType)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func TestServer_FormFilterMiddleware(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||
"whisper": {UseModelName: "whisper-large-v3"},
|
||||
}}
|
||||
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
mw.WriteField("model", "whisper")
|
||||
fw, _ := mw.CreateFormFile("file", "a.wav")
|
||||
fw.Write([]byte("xx"))
|
||||
mw.Close()
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
var gotModel string
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseMultipartForm(32 << 20)
|
||||
gotModel = r.MultipartForm.Value["model"][0]
|
||||
})
|
||||
CreateFormFilterMiddleware(cfg)(final).ServeHTTP(httptest.NewRecorder(), r)
|
||||
|
||||
if gotModel != "whisper-large-v3" {
|
||||
t.Errorf("model rewritten to %q, want whisper-large-v3", gotModel)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// inflightCounter tracks the number of in-flight model-dispatched requests.
|
||||
type inflightCounter struct {
|
||||
total atomic.Int64
|
||||
}
|
||||
|
||||
func (c *inflightCounter) Increment() int64 { return c.total.Add(1) }
|
||||
func (c *inflightCounter) Decrement() int64 { return c.total.Add(-1) }
|
||||
func (c *inflightCounter) Current() int64 { return c.total.Load() }
|
||||
|
||||
// CreateInflightMiddleware returns middleware that increments the counter on
|
||||
// entry and decrements on exit, emitting an InFlightRequestsEvent for each.
|
||||
func CreateInflightMiddleware(c *inflightCounter) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Increment())})
|
||||
defer func() {
|
||||
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Decrement())})
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||
// wiring each one's output per the logToStdout config value. The proxy and
|
||||
// upstream monitors write into muxlog (rather than os.Stdout directly) so
|
||||
// muxlog accumulates a combined history for the /logs endpoints, while each
|
||||
// monitor keeps its own per-source history and event subscribers.
|
||||
//
|
||||
// Behaviour matches the legacy ProxyManager:
|
||||
//
|
||||
// - none: everything discarded
|
||||
// - both: proxy + upstream both routed to muxlog -> stdout
|
||||
// - upstream: only upstream routed to muxlog -> stdout; proxy discarded
|
||||
// - proxy: only proxy routed to muxlog -> stdout; upstream discarded
|
||||
//
|
||||
// An empty or unrecognised value behaves like "proxy".
|
||||
func NewLoggers(logToStdout string) (muxlog, proxylog, upstreamlog *logmon.Monitor) {
|
||||
switch logToStdout {
|
||||
case config.LogToStdoutNone:
|
||||
muxlog = logmon.NewWriter(io.Discard)
|
||||
proxylog = logmon.NewWriter(io.Discard)
|
||||
upstreamlog = logmon.NewWriter(io.Discard)
|
||||
case config.LogToStdoutBoth:
|
||||
muxlog = logmon.NewWriter(os.Stdout)
|
||||
proxylog = logmon.NewWriter(muxlog)
|
||||
upstreamlog = logmon.NewWriter(muxlog)
|
||||
case config.LogToStdoutUpstream:
|
||||
muxlog = logmon.NewWriter(os.Stdout)
|
||||
proxylog = logmon.NewWriter(io.Discard)
|
||||
upstreamlog = logmon.NewWriter(muxlog)
|
||||
default:
|
||||
// config.LogToStdoutProxy, and the fallback for an unset value.
|
||||
muxlog = logmon.NewWriter(os.Stdout)
|
||||
proxylog = logmon.NewWriter(muxlog)
|
||||
upstreamlog = logmon.NewWriter(io.Discard)
|
||||
}
|
||||
return muxlog, proxylog, upstreamlog
|
||||
}
|
||||
|
||||
// handleLogs serves the historical proxy/upstream log. HTML clients are
|
||||
// redirected to the UI.
|
||||
func (s *Server) handleLogs(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.Header.Get("Accept"), "text/html") {
|
||||
http.Redirect(w, r, "/ui/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write(s.muxlog.GetHistory())
|
||||
}
|
||||
|
||||
// getLogger resolves a log monitor by id. An empty id maps to the combined
|
||||
// muxlog; "proxy" and "upstream" select the respective monitors.
|
||||
func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
||||
switch logMonitorID {
|
||||
case "":
|
||||
return s.muxlog, nil
|
||||
case "proxy":
|
||||
return s.proxylog, nil
|
||||
case "upstream":
|
||||
return s.upstreamlog, nil
|
||||
default:
|
||||
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||
return log, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogStream tails a log monitor: it writes the history then streams live
|
||||
// log data until the client disconnects or the server shuts down.
|
||||
func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering streamed logs
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorID := strings.TrimPrefix(r.PathValue("logMonitorID"), "/")
|
||||
// Strip a query string if it leaked into the path segment.
|
||||
if idx := strings.Index(logMonitorID, "?"); idx != -1 {
|
||||
logMonitorID = logMonitorID[:idx]
|
||||
}
|
||||
|
||||
logger, err := s.getLogger(logMonitorID)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
_, skipHistory := r.URL.Query()["no-history"]
|
||||
if !skipHistory {
|
||||
if history := logger.GetHistory(); len(history) != 0 {
|
||||
w.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
cancelSub := logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
}
|
||||
})
|
||||
defer cancelSub()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-s.shutdownCtx.Done():
|
||||
return
|
||||
case data := <-sendChan:
|
||||
w.Write(data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestLogPathSkips lists path prefixes excluded from the access log because
|
||||
// they are polled frequently and would drown out useful entries.
|
||||
var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"}
|
||||
|
||||
// statusRecorder wraps an http.ResponseWriter to capture the response status
|
||||
// code and the number of body bytes written, so the access log can report
|
||||
// them. Flush is forwarded so streaming handlers (SSE) still work.
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) WriteHeader(code int) {
|
||||
sr.status = code
|
||||
sr.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) Write(b []byte) (int, error) {
|
||||
n, err := sr.ResponseWriter.Write(b)
|
||||
sr.size += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) Flush() {
|
||||
if f, ok := sr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// clientIP resolves the originating client address, preferring proxy headers
|
||||
// over the raw connection address.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if first, _, found := strings.Cut(xff, ","); found {
|
||||
return strings.TrimSpace(first)
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
if xr := r.Header.Get("X-Real-IP"); xr != "" {
|
||||
return strings.TrimSpace(xr)
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// CreateRequestLogMiddleware returns middleware that records one access-log
|
||||
// line per request to proxylog, in the legacy format:
|
||||
//
|
||||
// clientIP "METHOD PATH PROTO" status bodySize "UA" duration
|
||||
//
|
||||
// Frequently-polled health/metrics paths are skipped. The path is captured
|
||||
// before next runs because /upstream rewrites the request URL in place.
|
||||
func CreateRequestLogMiddleware(proxylog *logmon.Monitor) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, prefix := range requestLogPathSkips {
|
||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
ip, method, path, proto, ua := clientIP(r), r.Method, r.URL.Path, r.Proto, r.UserAgent()
|
||||
|
||||
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(rec, r)
|
||||
|
||||
proxylog.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||
ip, method, path, proto, rec.status, rec.size, ua, time.Since(start))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestServer_NewLoggers(t *testing.T) {
|
||||
t.Run("proxy mode routes proxy into muxlog, discards upstream", func(t *testing.T) {
|
||||
mux, proxy, upstream := NewLoggers(config.LogToStdoutProxy)
|
||||
proxy.Info("PROXYLINE")
|
||||
upstream.Info("UPSTREAMLINE")
|
||||
h := string(mux.GetHistory())
|
||||
if !strings.Contains(h, "PROXYLINE") {
|
||||
t.Errorf("muxlog missing proxy line: %q", h)
|
||||
}
|
||||
if strings.Contains(h, "UPSTREAMLINE") {
|
||||
t.Errorf("muxlog should not contain upstream line: %q", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("both mode routes proxy and upstream into muxlog", func(t *testing.T) {
|
||||
mux, proxy, upstream := NewLoggers(config.LogToStdoutBoth)
|
||||
proxy.Info("PROXYLINE")
|
||||
upstream.Info("UPSTREAMLINE")
|
||||
h := string(mux.GetHistory())
|
||||
if !strings.Contains(h, "PROXYLINE") || !strings.Contains(h, "UPSTREAMLINE") {
|
||||
t.Errorf("muxlog history = %q", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("none mode discards everything from muxlog", func(t *testing.T) {
|
||||
mux, proxy, upstream := NewLoggers(config.LogToStdoutNone)
|
||||
proxy.Info("PROXYLINE")
|
||||
upstream.Info("UPSTREAMLINE")
|
||||
if len(mux.GetHistory()) != 0 {
|
||||
t.Errorf("muxlog should be empty, got %q", mux.GetHistory())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleLogs_Plain(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.muxlog.Write([]byte("a log line"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "text/plain" {
|
||||
t.Errorf("Content-Type = %q, want text/plain", ct)
|
||||
}
|
||||
if w.Body.String() != "a log line" {
|
||||
t.Errorf("body = %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleLogs_HTMLRedirect(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
req.Header.Set("Accept", "text/html")
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want 302", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Location"); got != "/ui/" {
|
||||
t.Errorf("Location = %q, want /ui/", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ClientIP(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
setup func(*http.Request)
|
||||
want string
|
||||
}{
|
||||
{"remote addr", func(r *http.Request) { r.RemoteAddr = "10.0.0.5:1234" }, "10.0.0.5"},
|
||||
{"x-forwarded-for", func(r *http.Request) {
|
||||
r.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8")
|
||||
}, "1.2.3.4"},
|
||||
{"x-real-ip", func(r *http.Request) { r.Header.Set("X-Real-IP", "9.9.9.9") }, "9.9.9.9"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.RemoteAddr = ""
|
||||
c.setup(r)
|
||||
if got := clientIP(r); got != c.want {
|
||||
t.Errorf("clientIP() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RequestLogMiddleware(t *testing.T) {
|
||||
proxylog := logmon.NewWriter(io.Discard)
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte("hello"))
|
||||
})
|
||||
mw := CreateRequestLogMiddleware(proxylog)
|
||||
|
||||
t.Run("logs request", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
r.RemoteAddr = "192.168.1.1:5000"
|
||||
mw(final).ServeHTTP(httptest.NewRecorder(), r)
|
||||
|
||||
line := string(proxylog.GetHistory())
|
||||
for _, want := range []string{"192.168.1.1", "POST /v1/chat/completions", "201", "5"} {
|
||||
if !strings.Contains(line, want) {
|
||||
t.Errorf("log line %q missing %q", line, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for _, path := range []string{"/wol-health", "/api/performance", "/metrics"} {
|
||||
t.Run("skips "+path, func(t *testing.T) {
|
||||
skipLog := logmon.NewWriter(io.Discard)
|
||||
skipMW := CreateRequestLogMiddleware(skipLog)
|
||||
skipMW(final).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if len(skipLog.GetHistory()) != 0 {
|
||||
t.Errorf("%s should not be logged; got %q", path, skipLog.GetHistory())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/cache"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TokenMetrics holds token usage and performance metrics.
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
}
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs.
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
// ActivityLogEvent carries a single activity log entry to event subscribers.
|
||||
type ActivityLogEvent struct {
|
||||
Metrics ActivityLogEntry
|
||||
}
|
||||
|
||||
func (e ActivityLogEvent) Type() uint32 {
|
||||
return shared.ActivityLogEventID
|
||||
}
|
||||
|
||||
// metricsMonitor parses upstream responses for token statistics, keeps a
|
||||
// bounded in-memory ring of recent activity, and (when captures are enabled)
|
||||
// stores zstd+CBOR-compressed request/response captures in a sized cache.
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics ring.Buffer[ActivityLogEntry]
|
||||
nextID int
|
||||
logger *logmon.Monitor
|
||||
|
||||
enableCaptures bool
|
||||
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a metricsMonitor retaining up to maxMetrics entries.
|
||||
// captureBufferMB is the capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000
|
||||
}
|
||||
mm := &metricsMonitor{
|
||||
logger: logger,
|
||||
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
}
|
||||
if captureBufferMB > 0 {
|
||||
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||
}
|
||||
return mm
|
||||
}
|
||||
|
||||
// queueMetrics adds a metric to the ring and returns its assigned ID.
|
||||
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics.Push(metric)
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics.
|
||||
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return []ActivityLogEntry{}
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns the current metrics as a JSON array.
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
return json.Marshal(mp.getMetrics())
|
||||
}
|
||||
|
||||
// record parses a completed response body and stores/emits an activity entry.
|
||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
||||
// requests, with cf controlling which request/response parts are retained.
|
||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
ReqPath: r.URL.Path,
|
||||
RespContentType: recorder.Header().Get("Content-Type"),
|
||||
RespStatusCode: recorder.Status(),
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
queueAndEmit := func() {
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
}
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
decoded, err := decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
body = decoded
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsed.Tokens
|
||||
tm.DurationMs = parsed.DurationMs
|
||||
}
|
||||
} else if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// /infill responses are arrays; timings live in the last element (#463).
|
||||
if strings.HasPrefix(r.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsedMetrics.Tokens
|
||||
tm.DurationMs = parsedMetrics.DurationMs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", r.URL.Path)
|
||||
}
|
||||
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
if mp.enableCaptures {
|
||||
capture := ReqRespCapture{
|
||||
ID: tm.ID,
|
||||
ReqPath: r.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
}
|
||||
if cf&captureReqBody != 0 {
|
||||
capture.ReqBody = reqBody
|
||||
}
|
||||
if cf&captureRespHeaders != 0 {
|
||||
capture.RespHeaders = headerMap(recorder.Header())
|
||||
redactHeaders(capture.RespHeaders)
|
||||
delete(capture.RespHeaders, "Content-Encoding")
|
||||
}
|
||||
if cf&captureRespBody != 0 {
|
||||
capture.RespBody = body
|
||||
}
|
||||
if mp.addCapture(capture) {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
mp.emitMetric(tm)
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
var usagePaths = []string{"usage", "response.usage", "message.usage"}
|
||||
|
||||
// extractUsageTokens reads input/output/cached token counts from a usage
|
||||
// gjson.Result, handling the field-name differences across endpoints.
|
||||
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
|
||||
cached = -1
|
||||
if !usage.Exists() {
|
||||
return
|
||||
}
|
||||
|
||||
if v := usage.Get("prompt_tokens"); v.Exists() {
|
||||
input = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens"); v.Exists() {
|
||||
input = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("completion_tokens"); v.Exists() {
|
||||
output = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("output_tokens"); v.Exists() {
|
||||
output = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||
var (
|
||||
inputTokens, outputTokens int64
|
||||
cachedTokens int64 = -1
|
||||
hasAny bool
|
||||
timings gjson.Result
|
||||
)
|
||||
|
||||
prefix := []byte("data:")
|
||||
for offset := 0; offset < len(body); {
|
||||
nl := bytes.IndexByte(body[offset:], '\n')
|
||||
var line []byte
|
||||
if nl == -1 {
|
||||
line = body[offset:]
|
||||
offset = len(body)
|
||||
} else {
|
||||
line = body[offset : offset+nl]
|
||||
offset += nl + 1
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
continue
|
||||
}
|
||||
parsed := gjson.ParseBytes(data)
|
||||
|
||||
for _, path := range usagePaths {
|
||||
u := parsed.Get(path)
|
||||
if !u.Exists() {
|
||||
continue
|
||||
}
|
||||
i, o, c, ok := extractUsageTokens(u)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hasAny = true
|
||||
if i > 0 {
|
||||
inputTokens = i
|
||||
}
|
||||
if o > 0 {
|
||||
outputTokens = o
|
||||
}
|
||||
if c >= 0 {
|
||||
cachedTokens = c
|
||||
}
|
||||
}
|
||||
if t := parsed.Get("timings"); t.Exists() {
|
||||
timings = t
|
||||
hasAny = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAny {
|
||||
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||
input, output, cached, _ := extractUsageTokens(usage)
|
||||
return buildMetrics(modelID, start, input, output, cached, timings), nil
|
||||
}
|
||||
|
||||
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
|
||||
// optional llama-server timings (which override input/output and provide rates).
|
||||
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
outputTokens = timings.Get("predicted_n").Int()
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
},
|
||||
DurationMs: durationMs,
|
||||
}
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on the Content-Encoding header.
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters Accept-Encoding to only gzip/deflate so response
|
||||
// bodies remain decompressible for metrics parsing.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
// responseBodyCopier tees the upstream response to the client while buffering
|
||||
// it for metrics parsing. Status defaults to 200 until WriteHeader is called.
|
||||
type responseBodyCopier struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
status int
|
||||
wroteHeader bool
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w http.ResponseWriter) *responseBodyCopier {
|
||||
buf := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: buf,
|
||||
tee: io.MultiWriter(w, buf),
|
||||
status: http.StatusOK,
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
if w.wroteHeader {
|
||||
return
|
||||
}
|
||||
w.wroteHeader = true
|
||||
w.status = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Flush forwards to the underlying writer so streaming responses still flush.
|
||||
func (w *responseBodyCopier) Flush() {
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Status() int { return w.status }
|
||||
func (w *responseBodyCopier) StartTime() time.Time { return w.start }
|
||||
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||
// model-dispatched POST requests. It resolves the model, tees the response into
|
||||
// a buffer, and parses token usage once the upstream handler returns.
|
||||
func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if mm == nil || r.Method != http.MethodPost {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the model now so downstream dispatch hits the context
|
||||
// fast path; FetchContext restores the request body.
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
// Buffer the request body/headers for capture before dispatch
|
||||
// consumes them.
|
||||
cf := captureFieldsFor(r.URL.Path)
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mm.enableCaptures {
|
||||
if cf&captureReqBody != 0 && r.Body != nil {
|
||||
if buffered, err := io.ReadAll(r.Body); err == nil {
|
||||
reqBody = buffered
|
||||
r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(reqBody))
|
||||
}
|
||||
}
|
||||
if cf&captureReqHeaders != 0 {
|
||||
reqHeaders = headerMap(r.Header)
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
// Restrict Accept-Encoding to encodings we can decompress so the
|
||||
// buffered response body stays parseable.
|
||||
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
|
||||
r.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(w)
|
||||
next.ServeHTTP(recorder, r)
|
||||
mm.record(data.ModelID, r, recorder, cf, reqBody, reqHeaders)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestServer_ParseMetrics_ChatCompletions(t *testing.T) {
|
||||
body := `{"usage":{"prompt_tokens":12,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}`
|
||||
parsed := gjson.Parse(body)
|
||||
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
|
||||
if err != nil {
|
||||
t.Fatalf("parseMetrics: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 12 || entry.Tokens.OutputTokens != 7 || entry.Tokens.CachedTokens != 4 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ParseMetrics_Timings(t *testing.T) {
|
||||
body := `{"timings":{"prompt_n":20,"predicted_n":50,"prompt_per_second":100.0,"predicted_per_second":40.0,"prompt_ms":200,"predicted_ms":1250,"cache_n":8}}`
|
||||
parsed := gjson.Parse(body)
|
||||
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
|
||||
if err != nil {
|
||||
t.Fatalf("parseMetrics: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 20 || entry.Tokens.OutputTokens != 50 || entry.Tokens.CachedTokens != 8 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
if entry.Tokens.TokensPerSecond != 40.0 || entry.Tokens.PromptPerSecond != 100.0 {
|
||||
t.Fatalf("rates = %+v", entry.Tokens)
|
||||
}
|
||||
if entry.DurationMs != 1450 {
|
||||
t.Fatalf("DurationMs = %d, want 1450", entry.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessStreamingResponse(t *testing.T) {
|
||||
body := []byte("data: {\"choices\":[{}]}\n\n" +
|
||||
"data: {\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":33}}\n\n" +
|
||||
"data: [DONE]\n\n")
|
||||
entry, err := processStreamingResponse("m", time.Now(), body)
|
||||
if err != nil {
|
||||
t.Fatalf("processStreamingResponse: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 15 || entry.Tokens.OutputTokens != 33 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
|
||||
if _, err := processStreamingResponse("m", time.Now(), []byte("data: [DONE]\n\n")); err == nil {
|
||||
t.Fatal("expected error for stream with no usage data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||
// /infill responses are arrays; timings live in the last element.
|
||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||
parsed := gjson.Parse(body)
|
||||
timings := parsed.Get("timings")
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), timings)
|
||||
if err != nil {
|
||||
t.Fatalf("parseMetrics: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 5 || entry.Tokens.OutputTokens != 9 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||
// dispatch. It supersedes router.Server: it builds the local and peer routers
|
||||
// directly and dispatches between them itself.
|
||||
type Server struct {
|
||||
cfg config.Config
|
||||
|
||||
muxlog *logmon.Monitor
|
||||
proxylog *logmon.Monitor
|
||||
upstreamlog *logmon.Monitor
|
||||
|
||||
perf *perf.Monitor
|
||||
inflight *inflightCounter
|
||||
metrics *metricsMonitor
|
||||
build BuildInfo
|
||||
|
||||
local router.LocalRouter
|
||||
peer router.Router
|
||||
|
||||
mux *http.ServeMux
|
||||
handler http.Handler
|
||||
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
}
|
||||
|
||||
// modelPostJSONRoutes are endpoints with a model id in the JSON request body.
|
||||
var modelPostJSONRoutes = []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/responses",
|
||||
"/v1/completions",
|
||||
"/v1/messages",
|
||||
"/v1/messages/count_tokens",
|
||||
"/v1/embeddings",
|
||||
"/reranking",
|
||||
"/rerank",
|
||||
"/v1/rerank",
|
||||
"/v1/reranking",
|
||||
"/infill",
|
||||
"/completion",
|
||||
"/v1/audio/speech",
|
||||
"/v1/audio/voices",
|
||||
"/v1/images/generations",
|
||||
"/sdapi/v1/txt2img",
|
||||
"/sdapi/v1/img2img",
|
||||
|
||||
// versionless routes, the /v/ is stripped before the request is forwarded upstream
|
||||
// see issue #728
|
||||
"/v/chat/completions",
|
||||
"/v/responses",
|
||||
"/v/completions",
|
||||
"/v/messages",
|
||||
"/v/messages/count_tokens",
|
||||
"/v/embeddings",
|
||||
"/v/rerank",
|
||||
"/v/reranking",
|
||||
}
|
||||
|
||||
// modelPostFormRoutes are multipart/form-data endpoints with a model id in the form data
|
||||
var modelPostFormRoutes = []string{
|
||||
"/v1/audio/transcriptions",
|
||||
"/v1/images/edits",
|
||||
}
|
||||
|
||||
// modelGetRoutes are model-dispatched GET endpoints (the model arrives as a
|
||||
// query parameter).
|
||||
var modelGetRoutes = []string{
|
||||
"/v1/audio/voices",
|
||||
"/sdapi/v1/loras",
|
||||
}
|
||||
|
||||
// BuildInfo carries version metadata surfaced by GET /api/version.
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
Commit string
|
||||
Date string
|
||||
}
|
||||
|
||||
func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, upstreamlog *logmon.Monitor, perfMon *perf.Monitor, build BuildInfo) (*Server, error) {
|
||||
var local router.LocalRouter
|
||||
var err error
|
||||
|
||||
if cfg.Matrix != nil {
|
||||
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating matrix router: %w", err)
|
||||
}
|
||||
} else {
|
||||
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating group router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := router.NewPeer(cfg, proxylog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating peer router: %w", err)
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
muxlog: muxlog,
|
||||
proxylog: proxylog,
|
||||
upstreamlog: upstreamlog,
|
||||
perf: perfMon,
|
||||
inflight: &inflightCounter{},
|
||||
metrics: newMetricsMonitor(proxylog, cfg.MetricsMaxInMemory, cfg.CaptureBuffer),
|
||||
build: build,
|
||||
local: local,
|
||||
peer: peer,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
}
|
||||
s.routes()
|
||||
s.startPreload()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||
// router. The model is resolved once via router.FetchContext.
|
||||
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
stripVersionPrefix(r)
|
||||
|
||||
data, err := router.FetchContext(r, s.cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.local.Handles(data.ModelID):
|
||||
s.proxylog.Debugf("dispatch: using local process for model: %s", data.ModelID)
|
||||
s.local.ServeHTTP(w, r)
|
||||
case s.peer.Handles(data.ModelID):
|
||||
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendError(w, r, router.ErrNoRouterFound)
|
||||
}
|
||||
}
|
||||
|
||||
// stripVersionPrefix rewrites versionless /v/... requests to their /... form
|
||||
// before forwarding upstream (issue #728).
|
||||
func stripVersionPrefix(r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/v/") {
|
||||
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/v")
|
||||
}
|
||||
}
|
||||
|
||||
// routes builds the mux, registers every route, and wraps the mux with the
|
||||
// global CORS middleware.
|
||||
func (s *Server) routes() {
|
||||
authMW := CreateAuthMiddleware(s.cfg)
|
||||
filterMW := CreateFilterMiddleware(s.cfg)
|
||||
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
||||
|
||||
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
||||
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
||||
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
||||
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
||||
// other's Content-Type. Both run before the metrics middleware so it buffers
|
||||
// the rewritten body.
|
||||
modelChain := chain.New(
|
||||
authMW,
|
||||
CreateConcurrencyMiddleware(s.cfg),
|
||||
filterMW,
|
||||
formFilterMW,
|
||||
CreateInflightMiddleware(s.inflight),
|
||||
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||
)
|
||||
// Custom endpoints only need auth.
|
||||
apiChain := chain.New(authMW)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
dispatch := http.HandlerFunc(s.localPeerHandler)
|
||||
|
||||
for _, path := range modelPostJSONRoutes {
|
||||
mux.Handle("POST "+path, modelChain.Then(dispatch))
|
||||
}
|
||||
for _, path := range modelPostFormRoutes {
|
||||
mux.Handle("POST "+path, modelChain.Then(dispatch))
|
||||
}
|
||||
for _, path := range modelGetRoutes {
|
||||
mux.Handle("GET "+path, modelChain.Then(dispatch))
|
||||
}
|
||||
|
||||
// llama-swap API + custom endpoints.
|
||||
mux.Handle("GET /v1/models", apiChain.ThenFunc(s.handleListModels))
|
||||
mux.Handle("GET /logs", apiChain.ThenFunc(s.handleLogs))
|
||||
mux.Handle("GET /logs/stream", apiChain.ThenFunc(s.handleLogStream))
|
||||
mux.Handle("GET /logs/stream/{logMonitorID...}", apiChain.ThenFunc(s.handleLogStream))
|
||||
|
||||
mux.HandleFunc("GET /health", handleHealth)
|
||||
mux.HandleFunc("GET /wol-health", handleHealth)
|
||||
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||
|
||||
// Embedded UI.
|
||||
mux.HandleFunc("GET /ui/", s.handleUI)
|
||||
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||
|
||||
// Prometheus metrics (no auth, matches the legacy endpoint).
|
||||
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
||||
|
||||
// Operations endpoints.
|
||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||
|
||||
// Upstream passthrough.
|
||||
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
||||
|
||||
// API group (API-key protected) consumed by the UI.
|
||||
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||
mux.Handle("POST /api/models/unload/{model...}", apiChain.ThenFunc(s.handleAPIUnloadModel))
|
||||
mux.Handle("GET /api/events", apiChain.ThenFunc(s.handleAPIEvents))
|
||||
mux.Handle("GET /api/metrics", apiChain.ThenFunc(s.handleAPIMetrics))
|
||||
mux.Handle("GET /api/performance", apiChain.ThenFunc(s.handleAPIPerformance))
|
||||
mux.Handle("GET /api/version", apiChain.ThenFunc(s.handleAPIVersion))
|
||||
mux.Handle("GET /api/captures/{id}", apiChain.ThenFunc(s.handleAPICapture))
|
||||
|
||||
s.mux = mux
|
||||
s.handler = chain.New(CreateRequestLogMiddleware(s.proxylog), CreateCORSMiddleware()).Then(mux)
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// CloseStreams cancels long-lived response streams (Server-Sent Events) so a
|
||||
// graceful httpServer.Shutdown can drain without blocking on them. It does not
|
||||
// tear down routers; call Shutdown for that. Safe to call repeatedly.
|
||||
func (s *Server) CloseStreams() {
|
||||
s.shutdownFn()
|
||||
}
|
||||
|
||||
// Shutdown stops the local and peer routers in parallel. It is idempotent;
|
||||
// repeated calls return nil without re-running shutdown.
|
||||
//
|
||||
// Callers must drain inflight HTTP requests (httpServer.Shutdown) before
|
||||
// calling this, otherwise inflight requests 502 when their processes are torn
|
||||
// down. Call CloseStreams before httpServer.Shutdown so SSE streams do not
|
||||
// block the drain.
|
||||
func (s *Server) Shutdown(timeout time.Duration) error {
|
||||
if !s.shuttingDown.CompareAndSwap(false, true) {
|
||||
return nil
|
||||
}
|
||||
s.shutdownFn()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
var errs []error
|
||||
|
||||
for _, rt := range []router.Router{s.local, s.peer} {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(rt router.Router) {
|
||||
defer wg.Done()
|
||||
if err := rt.Shutdown(timeout); err != nil {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}(rt)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
@@ -0,0 +1,331 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// stubRouter is a minimal router.LocalRouter for Server dispatch tests.
|
||||
type stubRouter struct {
|
||||
models map[string]bool
|
||||
response string
|
||||
shutdownCalls atomic.Int32
|
||||
running map[string]process.ProcessState
|
||||
unloadCalls atomic.Int32
|
||||
loggers map[string]*logmon.Monitor
|
||||
}
|
||||
|
||||
func newStubRouter(models []string, response string) *stubRouter {
|
||||
m := make(map[string]bool, len(models))
|
||||
for _, id := range models {
|
||||
m[id] = true
|
||||
}
|
||||
return &stubRouter{models: m, response: response}
|
||||
}
|
||||
|
||||
func (s *stubRouter) Handles(model string) bool { return s.models[model] }
|
||||
func (s *stubRouter) Shutdown(_ time.Duration) error { s.shutdownCalls.Add(1); return nil }
|
||||
func (s *stubRouter) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(s.response))
|
||||
}
|
||||
|
||||
func (s *stubRouter) RunningModels() map[string]process.ProcessState { return s.running }
|
||||
func (s *stubRouter) Unload(_ time.Duration, _ ...string) { s.unloadCalls.Add(1) }
|
||||
func (s *stubRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
|
||||
if s.loggers != nil {
|
||||
if lg, ok := s.loggers[modelID]; ok {
|
||||
return lg, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// newTestServer wires a Server with stub routers and a built mux.
|
||||
func newTestServer(local router.LocalRouter, peer router.Router) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
proxylog := logmon.NewWriter(io.Discard)
|
||||
s := &Server{
|
||||
cfg: config.Config{},
|
||||
muxlog: logmon.NewWriter(io.Discard),
|
||||
proxylog: proxylog,
|
||||
upstreamlog: logmon.NewWriter(io.Discard),
|
||||
inflight: &inflightCounter{},
|
||||
metrics: newMetricsMonitor(proxylog, 0, 0),
|
||||
local: local,
|
||||
peer: peer,
|
||||
shutdownCtx: ctx,
|
||||
shutdownFn: cancel,
|
||||
}
|
||||
s.routes()
|
||||
return s
|
||||
}
|
||||
|
||||
func chatRequest(model string) *http.Request {
|
||||
body := strings.NewReader(`{"model":"` + model + `"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func TestServer_New_GroupConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (group): %v", err)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_New_MatrixConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
|
||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (matrix): %v", err)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RouteToLocalModel(t *testing.T) {
|
||||
s := newTestServer(
|
||||
newStubRouter([]string{"local-model"}, "local response"),
|
||||
newStubRouter(nil, ""),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, chatRequest("local-model"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.String() != "local response" {
|
||||
t.Errorf("body=%q want %q", w.Body.String(), "local response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RouteToPeerModel(t *testing.T) {
|
||||
s := newTestServer(
|
||||
newStubRouter(nil, ""),
|
||||
newStubRouter([]string{"peer-model"}, "peer response"),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, chatRequest("peer-model"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.String() != "peer response" {
|
||||
t.Errorf("body=%q want %q", w.Body.String(), "peer response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_UnknownModelReturns404(t *testing.T) {
|
||||
s := newTestServer(
|
||||
newStubRouter([]string{"local-model"}, ""),
|
||||
newStubRouter(nil, ""),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, chatRequest("unknown-model"))
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d want 404 body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_UnknownPathReturns404(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/does-not-exist", nil))
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Health(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
for _, path := range []string{"/health", "/wol-health"} {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if w.Code != http.StatusOK || w.Body.String() != "OK" {
|
||||
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CORSPreflight(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("status=%d want 204", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("Access-Control-Allow-Origin=%q want *", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Unload(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/unload", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "OK" {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := local.unloadCalls.Load(); got != 1 {
|
||||
t.Errorf("unloadCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Running(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{
|
||||
"m1": {
|
||||
Cmd: "llama-server",
|
||||
Proxy: "http://localhost:9999",
|
||||
UnloadAfter: 300,
|
||||
Name: "Model One",
|
||||
Description: "the first model",
|
||||
},
|
||||
}}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/running", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Running []runningModel `json:"running"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v body=%q", err, w.Body.String())
|
||||
}
|
||||
if len(resp.Running) != 1 {
|
||||
t.Fatalf("running=%v want 1 entry", resp.Running)
|
||||
}
|
||||
want := runningModel{
|
||||
Model: "m1",
|
||||
State: "ready",
|
||||
Cmd: "llama-server",
|
||||
Proxy: "http://localhost:9999",
|
||||
TTL: 300,
|
||||
Name: "Model One",
|
||||
Description: "the first model",
|
||||
}
|
||||
if resp.Running[0] != want {
|
||||
t.Errorf("got %+v want %+v", resp.Running[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Preload(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "ok")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Hooks: config.HooksConfig{
|
||||
OnStartup: config.HookOnStartup{Preload: []string{"m1"}},
|
||||
}}
|
||||
|
||||
got := make(chan shared.ModelPreloadedEvent, 1)
|
||||
cancel := event.On(func(e shared.ModelPreloadedEvent) { got <- e })
|
||||
defer cancel()
|
||||
|
||||
s.startPreload()
|
||||
|
||||
select {
|
||||
case e := <-got:
|
||||
if e.ModelName != "m1" || !e.Success {
|
||||
t.Errorf("event=%+v want {ModelName:m1 Success:true}", e)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("preload event not received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Shutdown_StopsRoutersAndIsIdempotent(t *testing.T) {
|
||||
local := newStubRouter([]string{"local-model"}, "")
|
||||
peer := newStubRouter(nil, "")
|
||||
s := newTestServer(local, peer)
|
||||
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("second Shutdown: %v", err)
|
||||
}
|
||||
if got := local.shutdownCalls.Load(); got != 1 {
|
||||
t.Errorf("local shutdownCalls=%d want 1", got)
|
||||
}
|
||||
if got := peer.shutdownCalls.Load(); got != 1 {
|
||||
t.Errorf("peer shutdownCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_LogStream_ModelID(t *testing.T) {
|
||||
buf := logmon.NewWriter(io.Discard)
|
||||
buf.Write([]byte("hello from model"))
|
||||
|
||||
local := newStubRouter([]string{"mymodel"}, "")
|
||||
local.loggers = map[string]*logmon.Monitor{"mymodel": buf}
|
||||
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"mymodel": {}}}
|
||||
|
||||
// Pre-cancel the context so the streaming loop exits immediately after
|
||||
// flushing history.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/stream/mymodel", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := w.Body.String(); got != "hello from model" {
|
||||
t.Errorf("body=%q want %q", got, "hello from model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_LogStream_UnknownID_Returns400(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs/stream/no-such-model", nil))
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d want 400", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// uiStaticFS holds the embedded UI build. The build is copied into ui_dist by
|
||||
// the Makefile's `ui` target; placeholder.txt keeps the embed valid before a
|
||||
// build has run.
|
||||
//
|
||||
//go:embed ui_dist
|
||||
var uiStaticFS embed.FS
|
||||
|
||||
// uiFS is the embedded UI rooted at ui_dist.
|
||||
var uiFS = func() http.FileSystem {
|
||||
sub, err := fs.Sub(uiStaticFS, "ui_dist")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return http.FS(sub)
|
||||
}()
|
||||
|
||||
// selectEncoding chooses the best pre-compressed encoding the client accepts.
|
||||
// It returns the encoding ("br" or "gzip") and the matching file extension.
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// serveCompressedFile serves name from fsys, preferring a pre-compressed
|
||||
// sibling (name+".br" / name+".gz") when the client accepts it. It returns an
|
||||
// error without writing a response when name cannot be served, so callers can
|
||||
// fall back (e.g. SPA routing).
|
||||
func serveCompressedFile(fsys http.FileSystem, w http.ResponseWriter, r *http.Request, name string) error {
|
||||
if encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding")); encoding != "" {
|
||||
if cf, err := fsys.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file, err := fsys.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stat.IsDir() {
|
||||
return fs.ErrNotExist
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleUI serves the embedded SPA under /ui/.
|
||||
func (s *Server) handleUI(w http.ResponseWriter, r *http.Request) {
|
||||
serveUI(uiFS, w, r)
|
||||
}
|
||||
|
||||
// serveUI serves the SPA from fsys. Real files are served with compression
|
||||
// support; unknown paths without a file extension fall back to index.html so
|
||||
// client-side routing works.
|
||||
func serveUI(fsys http.FileSystem, w http.ResponseWriter, r *http.Request) {
|
||||
name := strings.TrimPrefix(r.URL.Path, "/ui/")
|
||||
if name == "" {
|
||||
name = "index.html"
|
||||
}
|
||||
|
||||
if err := serveCompressedFile(fsys, w, r, name); err != nil {
|
||||
if strings.Contains(path.Base(name), ".") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if err := serveCompressedFile(fsys, w, r, "index.html"); err != nil {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFavicon serves /favicon.ico from the embedded UI build.
|
||||
func (s *Server) handleFavicon(w http.ResponseWriter, r *http.Request) {
|
||||
if err := serveCompressedFile(uiFS, w, r, "favicon.ico"); err != nil {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
placeholder so //go:embed ui_dist succeeds before the UI is built
|
||||
@@ -0,0 +1,92 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
)
|
||||
|
||||
func TestServer_SelectEncoding(t *testing.T) {
|
||||
cases := []struct {
|
||||
accept string
|
||||
encoding string
|
||||
ext string
|
||||
}{
|
||||
{"", "", ""},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"gzip, deflate, br", "br", ".br"},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.8", "br", ".br"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
enc, ext := selectEncoding(c.accept)
|
||||
if enc != c.encoding || ext != c.ext {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)", c.accept, enc, ext, c.encoding, c.ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func uiTestFS() http.FileSystem {
|
||||
return http.FS(fstest.MapFS{
|
||||
"index.html": {Data: []byte("<html>app</html>")},
|
||||
"app.js": {Data: []byte("plain")},
|
||||
"app.js.br": {Data: []byte("brotli")},
|
||||
"app.js.gz": {Data: []byte("gzipped")},
|
||||
"favicon.ico": {Data: []byte("icon")},
|
||||
})
|
||||
}
|
||||
|
||||
func serveUIRequest(t *testing.T, path, acceptEncoding string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
if acceptEncoding != "" {
|
||||
req.Header.Set("Accept-Encoding", acceptEncoding)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
serveUI(uiTestFS(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_File(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/app.js", "")
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", w.Code)
|
||||
}
|
||||
if w.Body.String() != "plain" {
|
||||
t.Errorf("body = %q, want plain", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_Brotli(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/app.js", "gzip, br")
|
||||
if got := w.Header().Get("Content-Encoding"); got != "br" {
|
||||
t.Fatalf("Content-Encoding = %q, want br", got)
|
||||
}
|
||||
if w.Body.String() != "brotli" {
|
||||
t.Errorf("body = %q, want brotli", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_IndexAndRoot(t *testing.T) {
|
||||
for _, path := range []string{"/ui/", "/ui/index.html"} {
|
||||
w := serveUIRequest(t, path, "")
|
||||
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
|
||||
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_SPAFallback(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/models", "")
|
||||
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
|
||||
t.Errorf("SPA fallback: status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_MissingFile(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/missing.js", "")
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package shared
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const ActivityLogEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
// ProcessStateChangeEvent is emitted whenever a process transitions between
|
||||
// lifecycle states. States are carried as strings so this package stays a leaf
|
||||
// (no import of internal/process).
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
OldState string
|
||||
NewState string
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
State ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
// Package configwatcher provides a simple cross-platform file watcher based
|
||||
// on os.Stat polling. It works correctly inside Docker containers where the
|
||||
// config file is bind-mounted as an individual file, and for k8s ConfigMap
|
||||
// projections (which present the file as a symlink to an atomically swapped
|
||||
// target) — both cases where inotify-based watchers are unreliable.
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultInterval = 2 * time.Second
|
||||
|
||||
type Watcher struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
OnChange func()
|
||||
}
|
||||
|
||||
type snapshot struct {
|
||||
exists bool
|
||||
modTime time.Time
|
||||
size int64
|
||||
}
|
||||
|
||||
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||
// OnChange whenever the file's modification time or size changes, or when
|
||||
// the file reappears after being missing. The baseline poll establishes
|
||||
// initial state and does not fire OnChange.
|
||||
func (w *Watcher) Run(ctx context.Context) {
|
||||
interval := w.Interval
|
||||
if interval <= 0 {
|
||||
interval = DefaultInterval
|
||||
}
|
||||
|
||||
prev := stat(w.Path)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cur := stat(w.Path)
|
||||
if changed(prev, cur) && w.OnChange != nil {
|
||||
w.OnChange()
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stat(path string) snapshot {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Printf("configwatcher: stat %s: %v", path, err)
|
||||
}
|
||||
return snapshot{}
|
||||
}
|
||||
return snapshot{
|
||||
exists: true,
|
||||
modTime: fi.ModTime(),
|
||||
size: fi.Size(),
|
||||
}
|
||||
}
|
||||
|
||||
func changed(prev, cur snapshot) bool {
|
||||
// Present → missing: stay quiet (likely a transient rename-style write).
|
||||
// Missing → present: fire so we reload as soon as the file comes back.
|
||||
if !cur.exists {
|
||||
return false
|
||||
}
|
||||
if !prev.exists {
|
||||
return true
|
||||
}
|
||||
return !prev.modTime.Equal(cur.modTime) || prev.size != cur.size
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testInterval = 25 * time.Millisecond
|
||||
|
||||
// startWatcher launches w.Run in a goroutine and returns a function that
|
||||
// cancels the context and waits for Run to return.
|
||||
func startWatcher(t *testing.T, w *Watcher) func() {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("watcher did not stop within 2s of cancel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCount blocks until counter reaches want or timeout elapses.
|
||||
func waitForCount(t *testing.T, counter *int64, want int64, timeout time.Duration) bool {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if atomic.LoadInt64(counter) >= want {
|
||||
return true
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestWatcher_NoFireOnBaseline(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
time.Sleep(testInterval * 5)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||
}
|
||||
|
||||
func TestWatcher_DetectsModTimeChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
// Force a known baseline mtime.
|
||||
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||
require.NoError(t, os.Chtimes(path, base, base))
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
// Let the baseline settle.
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Bump mtime well above the baseline so low-resolution filesystems still notice.
|
||||
require.NoError(t, os.Chtimes(path, base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||
}
|
||||
|
||||
func TestWatcher_DetectsSizeChangeWithSameModTime(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
require.NoError(t, err)
|
||||
originalMtime := fi.ModTime()
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.WriteFile(path, []byte("aaaaa"), 0o644))
|
||||
// Reset mtime back to the original so size is the only signal.
|
||||
require.NoError(t, os.Chtimes(path, originalMtime, originalMtime))
|
||||
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire on size change")
|
||||
}
|
||||
|
||||
func TestWatcher_SymlinkTargetSwap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
targetA := filepath.Join(dir, "targetA")
|
||||
targetB := filepath.Join(dir, "targetB")
|
||||
link := filepath.Join(dir, "config.yaml")
|
||||
|
||||
require.NoError(t, os.WriteFile(targetA, []byte("AAAA"), 0o644))
|
||||
require.NoError(t, os.WriteFile(targetB, []byte("BBBBBBBB"), 0o644))
|
||||
|
||||
if err := os.Symlink(targetA, link); err != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skipf("symlink creation requires privilege on Windows: %v", err)
|
||||
}
|
||||
t.Fatalf("os.Symlink: %v", err)
|
||||
}
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: link,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Atomic symlink swap (k8s ConfigMap pattern): create new symlink at a
|
||||
// temp name, then rename over the existing one.
|
||||
tmpLink := filepath.Join(dir, "config.yaml.tmp")
|
||||
require.NoError(t, os.Symlink(targetB, tmpLink))
|
||||
require.NoError(t, os.Rename(tmpLink, link))
|
||||
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after symlink target swap")
|
||||
}
|
||||
|
||||
func TestWatcher_FileMissingThenReturns(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Remove(path))
|
||||
time.Sleep(testInterval * 3)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "removal alone must not fire")
|
||||
|
||||
require.NoError(t, os.WriteFile(path, []byte("b"), 0o644))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when file returns")
|
||||
}
|
||||
|
||||
func TestWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
w := &Watcher{Path: path, Interval: testInterval}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() { w.Run(ctx); close(done) }()
|
||||
|
||||
time.Sleep(testInterval * 2)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not return within 2s of cancel")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user