Compare commits

...

4 Commits

Author SHA1 Message Date
David Wen Riccardi-Zhu 6516532568 Add optional TLS support (#340)
* Add optional TLS support

Introduce HTTPS support with net/http Server.ListenAndServeTLS.

This should enable the option of serving via HTTPS without a reverse
proxy.

Add two flags:
- tls-cert-file (path to the TLS certificate file)
- tls-key-file (path to the TLS private key file)

Both flags must be supplied together; otherwise exit with error.

If both flags are present, call srv.ListenAndServeTLS.
If not, fall back to the existing srv.ListenAndServe (HTTP); no changes
to existing non‑TLS behavior.
2025-10-15 19:29:02 -07:00
David Wen Riccardi-Zhu d58a8b85bf Refactor to use httputil.ReverseProxy (#342)
* Refactor to use httputil.ReverseProxy

Refactor manual HTTP proxying logic in Process.ProxyRequest to use the standard
library's httputil.ReverseProxy.

* Refactor TestProcess_ForceStopWithKill test

Update to handle behavior with httputil.ReverseProxy.

* Fix gin interface conversion panic
2025-10-13 16:47:04 -07:00
Benson Wong caf9e98b1e Fix race conditions in proxy.Process (#349)
- Fix data races found in proxy.Process by go's race detector. 
- Add data race detection to the CI tests. 

Fixes #348
2025-10-13 16:42:49 -07:00
Benson Wong 539278343b ui: tweak vertical space for mobile (#343) 2025-10-10 10:05:36 -07:00
8 changed files with 272 additions and 132 deletions
+1 -1
View File
@@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt
# for CI - full test (takes longer) # for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt test-all: proxy/ui_dist/placeholder.txt
go test -count=1 ./proxy/... go test -race -count=1 ./proxy/...
ui/node_modules: ui/node_modules:
cd ui && npm install cd ui && npm install
+29 -3
View File
@@ -28,7 +28,9 @@ var (
func main() { func main() {
// Define a command-line flag for the port // Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name") configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port") listenStr := flag.String("listen", "", "listen ip/port")
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
keyFile := flag.String("tls-key-file", "", "TLS key file")
showVersion := flag.Bool("version", false, "show version of build") showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change") watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
@@ -55,6 +57,23 @@ func main() {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
// Validate TLS flags.
var useTLS = (*certFile != "" && *keyFile != "")
if (*certFile != "" && *keyFile == "") ||
(*certFile == "" && *keyFile != "") {
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
os.Exit(1)
}
// Set default ports.
if *listenStr == "" {
defaultPort := ":8080"
if useTLS {
defaultPort = ":8443"
}
listenStr = &defaultPort
}
// Setup channels for server management // Setup channels for server management
exitChan := make(chan struct{}) exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
@@ -167,9 +186,16 @@ func main() {
}() }()
// Start server // Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() { go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { var err error
if useTLS {
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
err = srv.ListenAndServeTLS(*certFile, *keyFile)
} else {
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
log.Fatalf("Fatal server error: %v\n", err) log.Fatalf("Fatal server error: %v\n", err)
} }
}() }()
+8
View File
@@ -3,6 +3,7 @@ package config
import ( import (
"fmt" "fmt"
"io" "io"
"net/url"
"os" "os"
"regexp" "regexp"
"runtime" "runtime"
@@ -342,6 +343,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
} }
// Validate the proxy URL.
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf(
"model %s: invalid proxy URL: %w", modelId, err,
)
}
config.Models[modelId] = modelConfig config.Models[modelId] = modelConfig
} }
+91 -69
View File
@@ -4,14 +4,14 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"net/url" "net/url"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@@ -39,11 +39,13 @@ const (
) )
type Process struct { type Process struct {
ID string ID string
config config.ModelConfig config config.ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
reverseProxy *httputil.ReverseProxy
// PR #155 called to cancel the upstream process // PR #155 called to cancel the upstream process
cmdMutex sync.RWMutex
cancelUpstream context.CancelFunc cancelUpstream context.CancelFunc
// closed when command exits // closed when command exits
@@ -55,12 +57,14 @@ type Process struct {
healthCheckTimeout int healthCheckTimeout int
healthCheckLoopInterval time.Duration healthCheckLoopInterval time.Duration
lastRequestHandled time.Time lastRequestHandledMutex sync.RWMutex
lastRequestHandled time.Time
stateMutex sync.RWMutex stateMutex sync.RWMutex
state ProcessState state ProcessState
inFlightRequests sync.WaitGroup inFlightRequests sync.WaitGroup
inFlightRequestsCount atomic.Int32
// used to block on multiple start() calls // used to block on multiple start() calls
waitStarting sync.WaitGroup waitStarting sync.WaitGroup
@@ -81,10 +85,29 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
concurrentLimit = config.ConcurrencyLimit concurrentLimit = config.ConcurrencyLimit
} }
// Setup the reverse proxy.
proxyURL, err := url.Parse(config.Proxy)
if err != nil {
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
}
var reverseProxy *httputil.ReverseProxy
if proxyURL != nil {
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
reverseProxy.ModifyResponse = func(resp *http.Response) error {
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
}
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
reverseProxy: reverseProxy,
cancelUpstream: nil, cancelUpstream: nil,
processLogger: processLogger, processLogger: processLogger,
proxyLogger: proxyLogger, proxyLogger: proxyLogger,
@@ -107,6 +130,20 @@ func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger return p.processLogger
} }
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
func (p *Process) setLastRequestHandled(t time.Time) {
p.lastRequestHandledMutex.Lock()
defer p.lastRequestHandledMutex.Unlock()
p.lastRequestHandled = t
}
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
func (p *Process) getLastRequestHandled() time.Time {
p.lastRequestHandledMutex.RLock()
defer p.lastRequestHandledMutex.RUnlock()
return p.lastRequestHandled
}
// custom error types for swapping state // custom error types for swapping state
var ( var (
ErrExpectedStateMismatch = errors.New("expected state mismatch") ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -130,6 +167,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
} }
p.state = newState p.state = newState
// Atomically increment waitStarting when entering StateStarting
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
if newState == StateStarting {
p.waitStarting.Add(1)
}
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState) p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState}) event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil return p.state, nil
@@ -158,6 +202,15 @@ func (p *Process) CurrentState() ProcessState {
return p.state return p.state
} }
// forceState forces the process state to the new state with mutex protection.
// This should only be used in exceptional cases where the normal state transition
// validation via swapState() cannot be used.
func (p *Process) forceState(newState ProcessState) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.state = newState
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready // start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called // it is a private method because starting is automatic but stopping can be called
// at any time. // at any time.
@@ -191,7 +244,7 @@ func (p *Process) start() error {
} }
} }
p.waitStarting.Add(1) // waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
defer p.waitStarting.Done() defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background()) cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
@@ -201,8 +254,11 @@ func (p *Process) start() error {
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...) p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout p.cmd.WaitDelay = p.gracefulStopTimeout
p.cmdMutex.Lock()
p.cancelUpstream = ctxCancelUpstream p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{}) p.cmdWaitChan = make(chan struct{})
p.cmdMutex.Unlock()
p.failedStartCount++ // this will be reset to zero when the process has successfully started p.failedStartCount++ // this will be reset to zero when the process has successfully started
@@ -212,7 +268,7 @@ func (p *Process) start() error {
// Set process state to failed // Set process state to failed
if err != nil { if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil { if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state p.forceState(StateStopped) // force it into a stopped state
return fmt.Errorf( return fmt.Errorf(
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v", "failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
strings.Join(args, " "), err, curState, swapErr, strings.Join(args, " "), err, curState, swapErr,
@@ -285,10 +341,12 @@ func (p *Process) start() error {
return return
} }
// wait for all inflight requests to complete and ticker // skip the TTL check if there are inflight requests
p.inFlightRequests.Wait() if p.inFlightRequestsCount.Load() != 0 {
continue
}
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.getLastRequestHandled()) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
@@ -344,7 +402,7 @@ func (p *Process) Shutdown() {
p.stopCommand() p.stopCommand()
// just force it to this state since there is no recovery from shutdown // just force it to this state since there is no recovery from shutdown
p.state = StateShutdown p.forceState(StateShutdown)
} }
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -355,13 +413,18 @@ func (p *Process) stopCommand() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime)) p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}() }()
if p.cancelUpstream == nil { p.cmdMutex.RLock()
cancelUpstream := p.cancelUpstream
cmdWaitChan := p.cmdWaitChan
p.cmdMutex.RUnlock()
if cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID) p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
return return
} }
p.cancelUpstream() cancelUpstream()
<-p.cmdWaitChan <-cmdWaitChan
} }
func (p *Process) checkHealthEndpoint(healthURL string) error { func (p *Process) checkHealthEndpoint(healthURL string) error {
@@ -418,8 +481,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
} }
p.inFlightRequests.Add(1) p.inFlightRequests.Add(1)
p.inFlightRequestsCount.Add(1)
defer func() { defer func() {
p.lastRequestHandled = time.Now() p.setLastRequestHandled(time.Now())
p.inFlightRequestsCount.Add(-1)
p.inFlightRequests.Done() p.inFlightRequests.Done()
}() }()
@@ -434,56 +499,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
startDuration = time.Since(beginStartTime) startDuration = time.Since(beginStartTime)
} }
proxyTo := p.config.Proxy if p.reverseProxy != nil {
client := &http.Client{} p.reverseProxy.ServeHTTP(w, r)
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body) } else {
if err != nil { http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
w.Header().Set("X-Accel-Buffering", "no")
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
} }
totalTime := time.Since(requestBeginTime) totalTime := time.Since(requestBeginTime)
@@ -519,13 +538,16 @@ func (p *Process) waitForCmd() {
case StateStopping: case StateStopping:
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err) p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
p.state = StateStopped p.forceState(StateStopped)
} }
default: default:
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState) p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
p.state = StateStopped // force it to be in this state p.forceState(StateStopped) // force it to be in this state
} }
p.cmdMutex.Lock()
close(p.cmdWaitChan) close(p.cmdWaitChan)
p.cmdMutex.Unlock()
} }
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully // cmdStopUpstreamProcess attemps to stop the upstream process gracefully
+3 -1
View File
@@ -436,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host") assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else { } else {
assert.Contains(t, w.Body.String(), "unexpected EOF") // Upstream may be killed mid-response.
// Assert an incomplete or partial response.
assert.NotEqual(t, "12345", w.Body.String())
} }
close(waitChan) close(waitChan)
+70 -34
View File
@@ -21,6 +21,32 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder.
// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier."
// The tests can panic otherwise:
// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
// See: https://github.com/gin-gonic/gin/issues/1815
// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
@@ -37,7 +63,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
for _, modelName := range []string{"model1", "model2"} { for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -74,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
t.Run(requestedModel, func(t *testing.T) { t.Run(requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -116,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
for _, requestedModel := range tests { for _, requestedModel := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -159,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, key) reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
@@ -212,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
// Create a test request // Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil) req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Add("Origin", "i-am-the-origin") req.Header.Add("Origin", "i-am-the-origin")
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
// Call the listModelsHandler // Call the listModelsHandler
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
@@ -311,7 +337,7 @@ models:
proxy := New(processedConfig) proxy := New(processedConfig)
req := httptest.NewRequest("GET", "/v1/models", nil) req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -387,7 +413,7 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Request models list // Request models list
req := httptest.NewRequest("GET", "/v1/models", nil) req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -448,7 +474,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
defer wg.Done() defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up // send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
@@ -476,12 +502,12 @@ func TestProxyManager_Unload(t *testing.T) {
proxy := New(conf) proxy := New(conf)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil) req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder() w = CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
@@ -519,7 +545,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
for _, modelName := range []string{"model1", "model2"} { for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
} }
@@ -527,7 +553,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState()) assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil) req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
if !assert.Equal(t, w.Body.String(), "OK") { if !assert.Equal(t, w.Body.String(), "OK") {
@@ -571,7 +597,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
t.Run("no models loaded", func(t *testing.T) { t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil) req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -589,13 +615,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Load just a model. // Load just a model.
reqBody := `{"model":"model1"}` reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint. // Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil) req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder() w = CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
var response RunningResponse var response RunningResponse
@@ -647,7 +673,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
// Create the request with the multipart form data // Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
// Verify the response // Verify the response
@@ -682,7 +708,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) { t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -716,7 +742,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
// Create the request with the multipart form data // Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
// Verify the response // Verify the response
@@ -784,7 +810,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code) assert.Equal(t, tt.expectedStatus, w.Code)
@@ -812,7 +838,7 @@ models:
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("main model name", func(t *testing.T) { t.Run("main model name", func(t *testing.T) {
req := httptest.NewRequest("GET", "/upstream/model1/test", nil) req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, "model1", rec.Body.String())
@@ -820,7 +846,7 @@ models:
t.Run("model alias", func(t *testing.T) { t.Run("model alias", func(t *testing.T) {
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil) req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, "model1", rec.Body.String())
@@ -841,7 +867,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -869,7 +895,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}` reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -900,7 +926,7 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
// Make a non-streaming request // Make a non-streaming request
reqBody := `{"model":"model1", "stream": false}` reqBody := `{"model":"model1", "stream": false}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -935,7 +961,7 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
// Make a streaming request // Make a streaming request
reqBody := `{"model":"model1", "stream": true}` reqBody := `{"model":"model1", "stream": true}`
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -967,7 +993,7 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/health", nil) req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String()) assert.Equal(t, "OK", rec.Body.String())
@@ -988,7 +1014,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) {
reqBody := `{"model":"model1"}` reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -1075,18 +1101,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
t.Run(endpoint, func(t *testing.T) { t.Run(endpoint, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
req := httptest.NewRequest("GET", endpoint, nil) req := httptest.NewRequest("GET", endpoint, nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
// We don't need the handler to fully complete, just to set the headers // Run handler in goroutine and wait for context timeout
// so run it in a goroutine and check the headers after a short delay done := make(chan struct{})
go proxy.ServeHTTP(rec, req) go func() {
time.Sleep(10 * time.Millisecond) // give it time to start and write headers defer close(done)
proxy.ServeHTTP(rec, req)
}()
// Wait for either the handler to complete or context to timeout
<-ctx.Done()
// At this point, the handler has either finished or been cancelled
// Wait for the goroutine to fully exit before reading
<-done
// Now it's safe to read from rec - no more concurrent writes
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering")) assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
}) })
@@ -1109,7 +1145,7 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
reqBody := `{"model":"streaming-model"}` reqBody := `{"model":"streaming-model"}`
// simple-responder will return text/event-stream when stream=true is in the query // simple-responder will return text/event-stream when stream=true is in the query
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
rec := httptest.NewRecorder() rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
+2 -2
View File
@@ -5,7 +5,7 @@ import { useTheme } from "../contexts/ThemeProvider";
import ConnectionStatusIcon from "./ConnectionStatus"; import ConnectionStatusIcon from "./ConnectionStatus";
export function Header() { export function Header() {
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle } = useTheme(); const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
const handleTitleChange = useCallback( const handleTitleChange = useCallback(
(newTitle: string) => { (newTitle: string) => {
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap"); setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
@@ -17,7 +17,7 @@ export function Header() {
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`; `text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
return ( return (
<header className="flex items-center justify-between bg-surface border-b border-border p-2 px-4 h-[75px]"> <header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
{screenWidth !== "xs" && screenWidth !== "sm" && ( {screenWidth !== "xs" && screenWidth !== "sm" && (
<h1 <h1
contentEditable contentEditable
+68 -22
View File
@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState"; import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels"; import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider"; import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri"; import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
export default function ModelsPage() { export default function ModelsPage() {
const { isNarrow } = useTheme(); const { isNarrow } = useTheme();
@@ -38,9 +38,11 @@ export default function ModelsPage() {
function ModelsPanel() { function ModelsPanel() {
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI(); const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
const { isNarrow } = useTheme();
const [isUnloading, setIsUnloading] = useState(false); const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true); const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
const [menuOpen, setMenuOpen] = useState(false);
const filteredModels = useMemo(() => { const filteredModels = useMemo(() => {
return models.filter((model) => showUnlisted || !model.unlisted); return models.filter((model) => showUnlisted || !model.unlisted);
@@ -66,33 +68,77 @@ function ModelsPanel() {
return ( return (
<div className="card h-full flex flex-col"> <div className="card h-full flex flex-col">
<div className="shrink-0"> <div className="shrink-0">
<h2>Models</h2> <div className="flex justify-between items-baseline">
<div className="flex justify-between"> <h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
<div className="flex gap-2"> {isNarrow && (
<button <div className="relative">
className="btn text-base flex items-center gap-2" <button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
onClick={toggleIdorName} <RiMenuFill size="20" />
style={{ lineHeight: "1.2" }} </button>
> {menuOpen && (
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"} <div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
</button> <button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
toggleIdorName();
setMenuOpen(false);
}}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
</button>
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
setShowUnlisted(!showUnlisted);
setMenuOpen(false);
}}
>
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
</button>
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
handleUnloadAllModels();
setMenuOpen(false);
}}
disabled={isUnloading}
>
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
)}
</div>
)}
</div>
{!isNarrow && (
<div className="flex justify-between">
<div className="flex gap-2">
<button
className="btn text-base flex items-center gap-2"
onClick={toggleIdorName}
style={{ lineHeight: "1.2" }}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
</button>
<button
className="btn text-base flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
>
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
</button>
</div>
<button <button
className="btn text-base flex items-center gap-2" className="btn text-base flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)} onClick={handleUnloadAllModels}
style={{ lineHeight: "1.2" }} disabled={isUnloading}
> >
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted <RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button> </button>
</div> </div>
<button )}
className="btn text-base flex items-center gap-2"
onClick={handleUnloadAllModels}
disabled={isUnloading}
>
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
</div> </div>
<div className="flex-1 overflow-y-auto"> <div className="flex-1 overflow-y-auto">