Introduce new routing backend (#790)
This is a huge backend change that essentially started with rewriting the concurrency handling for processes and blew up to a refactor of the entire application. In short these are the improvements: **Better state and life cycle management:** Life cycle management of processes has always been the trickiest part of the code. Juggling mutex locks between multiple locations to reduce race conditions was complex. Too complex for my feeble brain to build a simple mental model around as llama-swap gained more features. All of that has been refactored. Most of the locks are gone, replaced with a single run() that owns all state changes. There is one place to start from now to understand and extend routing logic. The improved life cycle management makes it easier to implement more complex swap optimization strategies in the future like #727. **Collation of requests:** llama-swap previously handled requests and swapping in the order they came in. For example requests for models in this order ABCABC would result in 5 swaps. Now those requests are handled in this order AABBCC. The result is less time waiting for swap under a high churn request queue. This fixes #588 #612. A possible future enhancement is to support a starvation parameter so swap can be forced when models have been waiting too long. **Shared base implementation for groups and swap matrix:** During the refactor it became clear that much of the swapping logic was shared between these two implementations. That is not surprising considering the swap matrix was added many moons after groups. Now they share a common base and their specific swap strategies are implemented into the swapPlanner interface. Requests for bespoke or specific swapping scenarios is a common theme in the issues. Now users can implement whatever bespoke and weird swapping strategy they want in their own fork. Just ask your agent of choice to implement swapPlanner. I'll still remaining more conservative on what actually lands in core llama-swap and will continue to evaluate PRs if the changes is good for everyone or just one specific use case. **AI / Agentic Disclosure:** I paid very close attention to the low level swap concurrency design and implementation. It's important to keep that essential part reliable, boring and no surprises. Backwards compatibility was also maintained, even the one way non-exclusive group model loading behaviour that people have rightly pointed out be a weird design decision. With the underlying swap core done the web server, api and UI sitting on top were largely ported over with Claude Code and Opus 4.7 in multiple phases. If you're curious I kept the changes in docs/newrouter-todo.md. I did several passes to make sure things weren't left behind. However, even frontier LLMs at the time of this PR still make small decisions that don't make a lot of sense. They get shit wrong all the time, just in small subtle way. That said, there's likely to be some new bugs introduced with this massive refactor. I'm fairly confident that there's no major architectural flaws that would cause goal seeking agents to make dumb, ugly code decisions. For a little while the legacy llama-swap will be available under cmd/legacy/llama-swap. The plan is to eventually delete that entry point as well as the proxy package. On a bit of a personal note, this PR is exciting and a bit sad for me. I hand wrote much of the original code and this PR ultimately replaces much of it. While the old code served as a good reference for the agent to implement the new stuff it still a bit sad to eventually delete it all.
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var simpleResponderPath string
|
||||
|
||||
func skipIfNoSimpleResponder(t *testing.T) {
|
||||
t.Helper()
|
||||
if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) {
|
||||
t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
if goos == "windows" {
|
||||
simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("getFreePort: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func simpleResponderCmd(t *testing.T, args ...string) (string, int) {
|
||||
port := getFreePort(t)
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
base := []string{cmdPath, fmt.Sprintf("-port %d", port)}
|
||||
base = append(base, args...)
|
||||
return strings.Join(base, " "), port
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type Process interface {
|
||||
// Run starts the process blocks until the process is terminated.
|
||||
// The timeout parameter controls how long to wait for the process to get
|
||||
// to a ready state to process traffic
|
||||
Run(timeout time.Duration) error
|
||||
|
||||
// WaitReady blocks until the process is ready to serve requests
|
||||
// or the context is cancelled. It returns nil when the process is ready
|
||||
WaitReady(context.Context) error
|
||||
|
||||
// Stop blocks until the process has terminated. It returns nil when
|
||||
// the process terminated as expected (exit 0)
|
||||
Stop(timeout time.Duration) error
|
||||
|
||||
// State returns the current state of the process
|
||||
// Note: this is a snapshot of the state at the time of the call
|
||||
// and may change at any time after the call returns.
|
||||
State() ProcessState
|
||||
|
||||
// ServeHTTP forwards requests to the underlying process
|
||||
// Calling it when the process is not ready will result in a
|
||||
// 503 response with a body indicating it is a llama-swap-error
|
||||
ServeHTTP(http.ResponseWriter, *http.Request)
|
||||
|
||||
// Logger returns the monitor that captures this process's stdout/stderr.
|
||||
Logger() *logmon.Monitor
|
||||
}
|
||||
@@ -0,0 +1,568 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
var ErrStartAborted = fmt.Errorf("aborted")
|
||||
|
||||
type runReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type stopReq struct {
|
||||
timeout time.Duration
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type waitReadyReq struct {
|
||||
respond chan error
|
||||
}
|
||||
|
||||
type startResult struct {
|
||||
cmd *exec.Cmd
|
||||
cmdDone chan struct{}
|
||||
handlerFn http.HandlerFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type ProcessCommand struct {
|
||||
id string
|
||||
config config.ModelConfig
|
||||
parentCtx context.Context
|
||||
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
runCh chan runReq
|
||||
stopCh chan stopReq
|
||||
waitReadyCh chan waitReadyReq
|
||||
|
||||
// current ProcessState. Written only by run(); read by State() via atomic load.
|
||||
state atomic.Value
|
||||
|
||||
// stores the active reverse-proxy handler when the process is running.
|
||||
// Written only by run(); read by ServeHTTP via atomic load.
|
||||
handler atomic.Pointer[http.HandlerFunc]
|
||||
|
||||
lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion
|
||||
inflight atomic.Int64 // current in-flight ServeHTTP calls
|
||||
}
|
||||
|
||||
var _ Process = (*ProcessCommand)(nil)
|
||||
|
||||
func New(
|
||||
parentCtx context.Context,
|
||||
id string,
|
||||
conf config.ModelConfig,
|
||||
processLogger *logmon.Monitor,
|
||||
proxyLogger *logmon.Monitor,
|
||||
) (*ProcessCommand, error) {
|
||||
p := &ProcessCommand{
|
||||
id: id,
|
||||
config: conf,
|
||||
parentCtx: parentCtx,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
|
||||
runCh: make(chan runReq),
|
||||
stopCh: make(chan stopReq),
|
||||
waitReadyCh: make(chan waitReadyReq),
|
||||
}
|
||||
p.state.Store(StateStopped)
|
||||
|
||||
go p.run()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Logger() *logmon.Monitor { return p.processLogger }
|
||||
|
||||
// run is the single-writer goroutine that owns all mutable lifecycle state
|
||||
// (current ProcessState, the running *exec.Cmd, the active reverse-proxy
|
||||
// handler, and the list of WaitReady subscribers). Every public method
|
||||
// (Run / Stop / State / WaitReady) is a thin client that sends a request on
|
||||
// one of the channels below and waits for a response — this funnels concurrent
|
||||
// callers through a single serialization point so the state machine never
|
||||
// observes a race.
|
||||
func (p *ProcessCommand) run() {
|
||||
// Mutable state — only read/written from this goroutine. ServeHTTP reads
|
||||
// p.handler concurrently, which is why handler is an atomic.Pointer.
|
||||
// p.state mirrors `state` so State() can observe transitions; setState
|
||||
// writes both.
|
||||
state := StateStopped
|
||||
setState := func(s ProcessState) {
|
||||
old := state
|
||||
state = s
|
||||
p.state.Store(s)
|
||||
if old != s {
|
||||
event.Emit(shared.ProcessStateChangeEvent{
|
||||
ProcessName: p.id,
|
||||
OldState: string(old),
|
||||
NewState: string(s),
|
||||
})
|
||||
}
|
||||
}
|
||||
var (
|
||||
cmd *exec.Cmd
|
||||
cmdDone <-chan struct{}
|
||||
readyWaiters []waitReadyReq
|
||||
// runResp parks the in-flight Run caller's response channel. The
|
||||
// interface contract is that Run blocks until the process is
|
||||
// terminated, so we hold this until Stop, parentCtx, or an
|
||||
// upstream exit unblocks it via respondRun.
|
||||
runResp chan<- error
|
||||
)
|
||||
|
||||
// notifyWaiters wakes every blocked WaitReady caller with the given result.
|
||||
// Used on transitions out of StateStarting (ready, failed, aborted, or
|
||||
// shutdown) — anything that resolves the "is it ready yet?" question.
|
||||
notifyWaiters := func(err error) {
|
||||
for _, w := range readyWaiters {
|
||||
select {
|
||||
case w.respond <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
readyWaiters = nil
|
||||
}
|
||||
|
||||
// respondRun delivers the final Run result, if a Run caller is parked.
|
||||
respondRun := func(err error) {
|
||||
if runResp != nil {
|
||||
runResp <- err
|
||||
runResp = nil
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
// Shutdown: parent context cancelled. Tear down any running process,
|
||||
// wake any pending WaitReady callers with an error, then exit the
|
||||
// goroutine permanently. Subsequent public-method calls will fail
|
||||
// because parentCtx.Done() unblocks their send-side selects.
|
||||
case <-p.parentCtx.Done():
|
||||
// Mark shutdown before killProcess so concurrent State() readers
|
||||
// stop treating this process as ready while the (possibly slow)
|
||||
// teardown is in progress.
|
||||
setState(StateShutdown)
|
||||
if cmd != nil {
|
||||
p.handler.Store(nil)
|
||||
p.killProcess(cmd, cmdDone, 100*time.Millisecond)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
return
|
||||
|
||||
// Upstream exited on its own (not via Stop). Drop handler state,
|
||||
// transition to Stopped, and unblock the parked Run caller.
|
||||
// cmdDone is nil while no process is running, so this case is
|
||||
// dormant outside of StateReady.
|
||||
case <-cmdDone:
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
p.handler.Store(nil)
|
||||
setState(StateStopped)
|
||||
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
|
||||
|
||||
// WaitReady: if we're already in a terminal-for-this-question state,
|
||||
// respond immediately; otherwise queue the caller and let a future
|
||||
// state transition wake them via notifyWaiters.
|
||||
case req := <-p.waitReadyCh:
|
||||
switch state {
|
||||
case StateReady:
|
||||
req.respond <- nil
|
||||
case StateShutdown:
|
||||
req.respond <- fmt.Errorf("[%s] shutdown", p.id)
|
||||
default:
|
||||
readyWaiters = append(readyWaiters, req)
|
||||
}
|
||||
|
||||
// Run: start the upstream process. Only valid from StateStopped.
|
||||
// doStart can take a long time (health-check polling), so it runs in
|
||||
// a separate goroutine and we wait on resultCh. While waiting we also
|
||||
// listen for an incoming Stop — that's how callers cancel an in-flight
|
||||
// start.
|
||||
case req := <-p.runCh:
|
||||
if state != StateStopped {
|
||||
req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state)
|
||||
continue
|
||||
}
|
||||
setState(StateStarting)
|
||||
|
||||
startCtx, cancelStart := context.WithCancel(context.Background())
|
||||
resultCh := make(chan startResult, 1)
|
||||
go func() {
|
||||
resultCh <- p.doStart(startCtx, req.timeout)
|
||||
}()
|
||||
|
||||
// pendingStop holds a Stop request that arrived mid-start, so we
|
||||
// can respond to it AFTER we've finished tearing the start down.
|
||||
var pendingStop *stopReq
|
||||
select {
|
||||
// doStart finished on its own — either successfully (latch
|
||||
// cmd/handler and move to Ready) or with an error (back to
|
||||
// Stopped). Either way wake WaitReady subscribers and reply
|
||||
// to the Run caller.
|
||||
case res := <-resultCh:
|
||||
if res.err == nil {
|
||||
cmd = res.cmd
|
||||
cmdDone = res.cmdDone
|
||||
fn := res.handlerFn
|
||||
p.handler.Store(&fn)
|
||||
setState(StateReady)
|
||||
notifyWaiters(nil)
|
||||
// Park the Run response — Run blocks until the process
|
||||
// terminates, so we only fire this when Stop, parentCtx,
|
||||
// or the upstream exit takes the process down.
|
||||
runResp = req.respond
|
||||
|
||||
// Start TTL goroutine if configured — self-terminates
|
||||
// when state leaves StateReady.
|
||||
if p.config.UnloadAfter > 0 {
|
||||
ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if p.State() != StateReady {
|
||||
return
|
||||
}
|
||||
if p.inflight.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter)
|
||||
p.Stop(10 * time.Second)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
setState(StateStopped)
|
||||
notifyWaiters(res.err)
|
||||
req.respond <- res.err
|
||||
}
|
||||
|
||||
// Stop arrived while doStart was still running. Cancel the
|
||||
// start context to abort it, then wait for doStart to return.
|
||||
// If doStart had already crossed the finish line before
|
||||
// cancellation took effect, it returns a live cmd that we
|
||||
// must kill ourselves. The Run caller gets ErrAbort; the Stop
|
||||
// caller is parked in pendingStop and answered below.
|
||||
case stop := <-p.stopCh:
|
||||
cancelStart()
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, stop.timeout)
|
||||
}
|
||||
setState(StateStopped)
|
||||
notifyWaiters(ErrStartAborted)
|
||||
req.respond <- ErrStartAborted
|
||||
pendingStop = &stop
|
||||
|
||||
// Parent context cancelled (e.g. config reload) while doStart
|
||||
// was still running. Stop() returns early when parentCtx is
|
||||
// done and never sends on stopCh, so we must handle shutdown
|
||||
// here to avoid leaving doStart running indefinitely.
|
||||
case <-p.parentCtx.Done():
|
||||
cancelStart()
|
||||
// Mark shutdown before tearing the process down: killProcess
|
||||
// may block (e.g. taskkill on Windows is slow to spawn), and
|
||||
// callers observing State() should see StateShutdown promptly
|
||||
// rather than a stale StateStarting.
|
||||
setState(StateShutdown)
|
||||
res := <-resultCh
|
||||
if res.cmd != nil {
|
||||
p.killProcess(res.cmd, res.cmdDone, 100*time.Millisecond)
|
||||
}
|
||||
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
|
||||
respondRun(fmt.Errorf("[%s] shutdown", p.id))
|
||||
return
|
||||
}
|
||||
// cancelStart is idempotent; calling it again here ensures the
|
||||
// context is released even on the success path (govet leak check).
|
||||
cancelStart()
|
||||
if pendingStop != nil {
|
||||
pendingStop.respond <- nil
|
||||
}
|
||||
|
||||
// Stop: tear down a running process.
|
||||
case stop := <-p.stopCh:
|
||||
if cmd != nil {
|
||||
setState(StateStopping)
|
||||
p.killProcess(cmd, cmdDone, stop.timeout)
|
||||
cmd = nil
|
||||
cmdDone = nil
|
||||
p.handler.Store(nil)
|
||||
}
|
||||
// Stop is a no-op (and not an error) when already Stopped — this
|
||||
// is what makes it idempotent for callers that don't track state.
|
||||
setState(StateStopped)
|
||||
respondRun(nil)
|
||||
stop.respond <- nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult {
|
||||
if p.config.Proxy == "" {
|
||||
return startResult{err: fmt.Errorf("upstream proxy missing")}
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)}
|
||||
}
|
||||
|
||||
proxyURL, err := url.Parse(p.config.Proxy)
|
||||
if err != nil {
|
||||
return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)}
|
||||
}
|
||||
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
reverseProxy.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream
|
||||
// disconnects after response headers have been sent. Recover here so the
|
||||
// streaming termination is treated as a normal client/upstream disconnect.
|
||||
// see: https://github.com/golang/go/issues/23643
|
||||
handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if rec == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec)
|
||||
}
|
||||
}
|
||||
}()
|
||||
reverseProxy.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stderr = p.processLogger
|
||||
cmd.Stdout = p.processLogger
|
||||
cmd.Env = append(cmd.Environ(), p.config.Env...)
|
||||
setProcAttributes(cmd)
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
if err := cmd.Start(); err != nil {
|
||||
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
|
||||
}
|
||||
|
||||
go func() {
|
||||
waitErr := cmd.Wait()
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
|
||||
} else if waitErr != nil {
|
||||
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
|
||||
}
|
||||
close(cmdDone)
|
||||
}()
|
||||
|
||||
if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
if checkEndpoint == "none" {
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
// Wait 250ms for the command to start up before health checking
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(healthCheckTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
default:
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: fmt.Errorf("health check timed out after %v", healthCheckTimeout)}
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
reverseProxy.ServeHTTP(rr, req)
|
||||
resp := rr.Result()
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
|
||||
break
|
||||
} else if startCtx.Err() != nil {
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
p.killProcess(cmd, cmdDone, 5*time.Second)
|
||||
return startResult{err: ErrStartAborted}
|
||||
case <-cmdDone:
|
||||
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
stopArgs, err := config.SanitizeCommand(
|
||||
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", cmd.Process.Pid)),
|
||||
)
|
||||
if err == nil {
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Env = cmd.Env
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Run()
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
} else {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(gracefulTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
case <-timer.C:
|
||||
cmd.Process.Kill()
|
||||
<-cmdDone
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ID() string {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Run(timeout time.Duration) error {
|
||||
req := runReq{
|
||||
timeout: timeout,
|
||||
respond: make(chan error, 1),
|
||||
}
|
||||
select {
|
||||
case p.runCh <- req:
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
select {
|
||||
case err := <-req.respond:
|
||||
return err
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) WaitReady(ctx context.Context) error {
|
||||
req := waitReadyReq{respond: make(chan error, 1)}
|
||||
select {
|
||||
case p.waitReadyCh <- req:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
select {
|
||||
case err := <-req.respond:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) Stop(timeout time.Duration) error {
|
||||
req := stopReq{
|
||||
timeout: timeout,
|
||||
respond: make(chan error, 1),
|
||||
}
|
||||
select {
|
||||
case p.stopCh <- req:
|
||||
case <-p.parentCtx.Done():
|
||||
return fmt.Errorf("[%s] shutdown", p.id)
|
||||
}
|
||||
return <-req.respond
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) State() ProcessState {
|
||||
if s, ok := p.state.Load().(ProcessState); ok {
|
||||
return s
|
||||
}
|
||||
return StateStopped
|
||||
}
|
||||
|
||||
func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fn := p.handler.Load()
|
||||
if fn == nil {
|
||||
http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
p.inflight.Add(1)
|
||||
defer func() {
|
||||
p.lastUse.Store(time.Now().UnixNano())
|
||||
p.inflight.Add(-1)
|
||||
}()
|
||||
(*fn)(w, r)
|
||||
}
|
||||
@@ -0,0 +1,646 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
const (
|
||||
testStartTimeout = 3 * time.Second
|
||||
testStopTimeout = 2 * time.Second
|
||||
testReturnTimeout = 1 * time.Second
|
||||
testPollInterval = 20 * time.Millisecond
|
||||
testLogPollInterval = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand {
|
||||
t.Helper()
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
p, err := New(context.Background(), t.Name(), conf, logger, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// runAsync starts Run in a goroutine and waits until the process is ready,
|
||||
// matching the new interface contract where Run blocks until the process is
|
||||
// terminated. Returns a channel that delivers Run's eventual error.
|
||||
func runAsync(t *testing.T, p *ProcessCommand) <-chan error {
|
||||
t.Helper()
|
||||
ch := make(chan error, 1)
|
||||
go func() { ch <- p.Run(testStartTimeout) }()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout)
|
||||
defer cancel()
|
||||
if err := p.WaitReady(ctx); err != nil {
|
||||
t.Fatalf("WaitReady: %v", err)
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func TestProcessCommand_StartStop(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// before start: no handler
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("before start: expected 503, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||
t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||
}
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Errorf("after Run: expected state %s, got %s", StateReady, got)
|
||||
}
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("after Run: expected 200, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); body != "hello" {
|
||||
t.Errorf("expected body %q, got %q", "hello", body)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
|
||||
}
|
||||
select {
|
||||
case err := <-runErr:
|
||||
if err != nil {
|
||||
t.Errorf("Run() after Stop: expected nil, got %v", err)
|
||||
}
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
|
||||
// after stop: handler cleared
|
||||
rr = httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("after stop: expected 503, got %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
|
||||
t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCommand_Run_Idempotent(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
if err := p.Run(testStartTimeout); err == nil {
|
||||
t.Error("second Run() while running: expected error, got nil")
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCommand_Stop_Idempotent(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() before Run(): %v", err)
|
||||
}
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("first Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("second Stop() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is
|
||||
// executing its health-check loop returns ErrAbort to the Run caller.
|
||||
//
|
||||
// A blocking mock HTTP server is used as the proxy so the test can deterministically
|
||||
// know when doStart is inside the health-check loop before issuing Stop.
|
||||
func TestProcessCommand_StopCancelsRun(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
healthCheckStarted := make(chan struct{}, 1)
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Signal that a health check is in-flight, then block until the client
|
||||
// cancels (which happens when Stop cancels the start context).
|
||||
select {
|
||||
case healthCheckStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-r.Context().Done()
|
||||
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer mock.Close()
|
||||
|
||||
// simple-responder is the real process; health checks go to the blocking mock.
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 30,
|
||||
})
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- p.Run(testStartTimeout)
|
||||
}()
|
||||
|
||||
// Block until doStart is actually performing a health check, guaranteeing
|
||||
// that Run is in-flight when Stop is called.
|
||||
<-healthCheckStarted
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
|
||||
if err := <-runErrCh; !errors.Is(err, ErrStartAborted) {
|
||||
t.Errorf("expected ErrStartAborted from Run, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the
|
||||
// parent context while doStart is health-checking causes the process to
|
||||
// transition to StateShutdown promptly, not wait for the health-check timeout.
|
||||
//
|
||||
// This is the config-reload race: Stop() returns early when parentCtx is
|
||||
// already done and never writes to stopCh, so without a parentCtx.Done()
|
||||
// case in the inner select, the process would keep loading indefinitely.
|
||||
func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
healthCheckStarted := make(chan struct{}, 1)
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case healthCheckStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-r.Context().Done()
|
||||
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer mock.Close()
|
||||
|
||||
parentCtx, cancelParent := context.WithCancel(context.Background())
|
||||
logger := logmon.NewWriter(io.Discard)
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p, err := New(parentCtx, t.Name(), config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 60,
|
||||
}, logger, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() { runErrCh <- p.Run(60 * time.Second) }()
|
||||
|
||||
<-healthCheckStarted
|
||||
|
||||
// Cancel parent context to simulate a config reload tearing down the old server.
|
||||
cancelParent()
|
||||
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if !strings.Contains(err.Error(), "shutdown") {
|
||||
t.Errorf("Run error = %v, want shutdown error", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("process did not shut down within 5s after parent context cancel during start")
|
||||
}
|
||||
|
||||
// Run() may return before the run() goroutine writes StateShutdown;
|
||||
// poll briefly to avoid a spurious race in the assertion.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if p.State() == StateShutdown {
|
||||
break
|
||||
}
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
if got := p.State(); got != StateShutdown {
|
||||
t.Errorf("after cancel: expected StateShutdown, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on
|
||||
// fresh processes to confirm they are reusable.
|
||||
func TestProcessCommand_RunStopCycle(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
for i := range 3 {
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code)
|
||||
}
|
||||
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("cycle %d Stop() error: %v", i, err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatalf("cycle %d: Run() did not return after Stop", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path:
|
||||
// the upstream responds healthy on /health (so Run completes), then on the
|
||||
// actual proxied request it hijacks the connection and closes it mid-body.
|
||||
// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error,
|
||||
// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover
|
||||
// and log the disconnect.
|
||||
//
|
||||
// Requests are issued through an httptest.NewServer wrapping the process so
|
||||
// the panic actually fires (httputil only panics on copy errors when the
|
||||
// request carries http.ServerContextKey, which a real server sets).
|
||||
//
|
||||
// see: https://github.com/golang/go/issues/23643
|
||||
func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// Send a Content-Length that promises 100 bytes, deliver only a few,
|
||||
// then slam the connection shut. The reverse proxy will see EOF
|
||||
// before the body is fully copied and panic with ErrAbortHandler.
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Errorf("upstream: hijack not supported")
|
||||
return
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Errorf("upstream: hijack: %v", err)
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial"))
|
||||
_ = conn.Close()
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
// Capture proxy log output so we can assert the recover message was
|
||||
// emitted by handlerFn.
|
||||
logBuf := &syncBuffer{}
|
||||
proxyLogger := logmon.NewWriter(logBuf)
|
||||
procLogger := logmon.NewWriter(io.Discard)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p, err := New(context.Background(), t.Name(), config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: upstream.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
}, procLogger, proxyLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { p.Stop(testStopTimeout) })
|
||||
|
||||
_ = runAsync(t, p)
|
||||
|
||||
// Wrap p in an httptest server so requests get http.ServerContextKey
|
||||
// automatically — that is what makes httputil.ReverseProxy raise the panic.
|
||||
front := httptest.NewServer(p)
|
||||
t.Cleanup(front.Close)
|
||||
|
||||
resp, err := http.Get(front.URL + "/disconnect")
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
const want = "recovered from upstream disconnection"
|
||||
deadline := time.Now().Add(testReturnTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if strings.Contains(logBuf.String(), want) {
|
||||
return
|
||||
}
|
||||
time.Sleep(testLogPollInterval)
|
||||
}
|
||||
t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String())
|
||||
}
|
||||
|
||||
// syncBuffer is a concurrent-safe bytes.Buffer for capturing logmon output.
|
||||
type syncBuffer struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *syncBuffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.Write(p)
|
||||
}
|
||||
|
||||
func (b *syncBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.String()
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL
|
||||
// automatically stops itself after the idle timeout has elapsed following its
|
||||
// last request.
|
||||
func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
|
||||
cfg := config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 1, // 1-second TTL
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, cfg)
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady, got %s", got)
|
||||
}
|
||||
|
||||
// Make one request to prime the last-use timestamp.
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Wait for the TTL goroutine to fire and the process to fully stop.
|
||||
// Poll for StateStopped directly to avoid racing the StateStopping
|
||||
// intermediate state that sits between StateReady and StateStopped.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for p.State() != StateStopped && time.Now().Before(deadline) {
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateStopped {
|
||||
t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped)
|
||||
}
|
||||
|
||||
// Run() should have returned nil (clean stop from TTL).
|
||||
select {
|
||||
case err := <-runErr:
|
||||
if err != nil {
|
||||
t.Errorf("Run() after TTL stop: expected nil, got %v", err)
|
||||
}
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after TTL-induced stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests
|
||||
// prevent the TTL goroutine from stopping the process, and that the TTL timer
|
||||
// resets after each request completes.
|
||||
func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 1, // 1-second TTL
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep sending requests for 1.5s — past the 1s TTL — and verify
|
||||
// the process never stops while traffic is flowing.
|
||||
stopAt := time.Now().Add(1500 * time.Millisecond)
|
||||
for time.Now().Before(stopAt) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
if p.State() != StateReady {
|
||||
t.Fatalf("process was stopped during active traffic (state=%s)", p.State())
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady while traffic was active, got %s", got)
|
||||
}
|
||||
|
||||
// Now stop manually to clean up.
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not
|
||||
// spawn a TTL goroutine — the process stays ready until explicitly stopped.
|
||||
func TestProcessCommand_TTL_ZeroDisables(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
t.Cleanup(mock.Close)
|
||||
|
||||
cmd, _ := simpleResponderCmd(t, "-silent")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: mock.URL,
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
UnloadAfter: 0, // disabled
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
defer func() {
|
||||
if p.State() == StateReady {
|
||||
p.Stop(testStopTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("expected StateReady, got %s", got)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
p.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 after request, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is
|
||||
// enough to confirm the process remains ready.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if got := p.State(); got != StateReady {
|
||||
t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got)
|
||||
}
|
||||
|
||||
// Cleanly stop.
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop() error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-runErr:
|
||||
case <-time.After(testReturnTimeout):
|
||||
t.Fatal("Run() did not return after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing
|
||||
// pairs to exercise the race detector and verify no deadlocks occur.
|
||||
func TestProcessCommand_ConcurrentRunStop(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
for range 10 {
|
||||
cmd, port := simpleResponderCmd(t, "-silent")
|
||||
cfg := config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
p := newProcessCommand(t, cfg)
|
||||
|
||||
runDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(runDone)
|
||||
p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race
|
||||
}()
|
||||
go func() {
|
||||
p.Stop(testStopTimeout) //nolint: errcheck
|
||||
}()
|
||||
|
||||
// Backstop: the racing Stop may have arrived before Run got on the
|
||||
// channel (making it a no-op), so keep stopping until Run unblocks.
|
||||
deadline := time.After(testStartTimeout)
|
||||
for done := false; !done; {
|
||||
select {
|
||||
case <-runDone:
|
||||
done = true
|
||||
case <-deadline:
|
||||
t.Fatal("Run did not return")
|
||||
case <-time.After(testPollInterval):
|
||||
p.Stop(testStopTimeout) //nolint: errcheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) {
|
||||
skipIfNoSimpleResponder(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
var transitions []shared.ProcessStateChangeEvent
|
||||
cancel := event.On(func(e shared.ProcessStateChangeEvent) {
|
||||
if e.ProcessName != t.Name() {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
transitions = append(transitions, e)
|
||||
mu.Unlock()
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
|
||||
p := newProcessCommand(t, config.ModelConfig{
|
||||
Cmd: cmd,
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
HealthCheckTimeout: 10,
|
||||
})
|
||||
|
||||
runErr := runAsync(t, p)
|
||||
if err := p.Stop(testStopTimeout); err != nil {
|
||||
t.Fatalf("Stop: %v", err)
|
||||
}
|
||||
<-runErr
|
||||
|
||||
// Events are delivered asynchronously; give the dispatcher a moment.
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
mu.Lock()
|
||||
n := len(transitions)
|
||||
mu.Unlock()
|
||||
if n >= 4 {
|
||||
break
|
||||
}
|
||||
time.Sleep(testPollInterval)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for _, e := range transitions {
|
||||
if e.OldState == e.NewState {
|
||||
t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState)
|
||||
}
|
||||
}
|
||||
|
||||
want := []string{
|
||||
string(StateStopped) + "->" + string(StateStarting),
|
||||
string(StateStarting) + "->" + string(StateReady),
|
||||
string(StateReady) + "->" + string(StateStopping),
|
||||
string(StateStopping) + "->" + string(StateStopped),
|
||||
}
|
||||
got := make([]string, len(transitions))
|
||||
for i, e := range transitions {
|
||||
got[i] = e.OldState + "->" + e.NewState
|
||||
}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("transitions = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("transitions = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user