Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 014a2fa9a3 | |||
| 5ceaef6144 |
@@ -1,7 +1,4 @@
|
|||||||

|

|
||||||

|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
@@ -67,8 +64,8 @@ models:
|
|||||||
# Default (and minimum) is 15 seconds
|
# Default (and minimum) is 15 seconds
|
||||||
healthCheckTimeout: 60
|
healthCheckTimeout: 60
|
||||||
|
|
||||||
# Valid log levels: debug, info (default), warn, error
|
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||||
logLevel: info
|
logRequests: true
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
@@ -219,15 +216,9 @@ Of course, CLI access is also supported:
|
|||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl http://host/logs'
|
curl http://host/logs'
|
||||||
|
|
||||||
# streams combined logs
|
# streams logs
|
||||||
curl -Ns 'http://host/logs/stream'
|
curl -Ns 'http://host/logs/stream'
|
||||||
|
|
||||||
# just llama-swap's logs
|
|
||||||
curl -Ns 'http://host/logs/stream/proxy'
|
|
||||||
|
|
||||||
# just upstream's logs
|
|
||||||
curl -Ns 'http://host/logs/stream/upstream'
|
|
||||||
|
|
||||||
# stream and filter logs with linux pipes
|
# stream and filter logs with linux pipes
|
||||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||||
|
|
||||||
|
|||||||
+3
-3
@@ -1,9 +1,9 @@
|
|||||||
# Seconds to wait for llama.cpp to be available to serve requests
|
# Seconds to wait for llama.cpp to be available to serve requests
|
||||||
# Default (and minimum): 15 seconds
|
# Default (and minimum): 15 seconds
|
||||||
healthCheckTimeout: 90
|
healthCheckTimeout: 15
|
||||||
|
|
||||||
# valid log levels: debug, info (default), warn, error
|
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||||
logLevel: info
|
logRequests: true
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
LogRequests bool `yaml:"logRequests"`
|
LogRequests bool `yaml:"logRequests"`
|
||||||
LogLevel string `yaml:"logLevel"`
|
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
Models map[string]ModelConfig `yaml:"models"`
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
|
||||||
|
|||||||
@@ -2,21 +2,11 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
"container/ring"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
LevelDebug LogLevel = iota
|
|
||||||
LevelInfo
|
|
||||||
LevelWarn
|
|
||||||
LevelError
|
|
||||||
)
|
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
clients map[chan []byte]bool
|
clients map[chan []byte]bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -25,10 +15,6 @@ type LogMonitor struct {
|
|||||||
|
|
||||||
// typically this can be os.Stdout
|
// typically this can be os.Stdout
|
||||||
stdout io.Writer
|
stdout io.Writer
|
||||||
|
|
||||||
// logging levels
|
|
||||||
level LogLevel
|
|
||||||
prefix string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLogMonitor() *LogMonitor {
|
func NewLogMonitor() *LogMonitor {
|
||||||
@@ -40,8 +26,6 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
|||||||
clients: make(map[chan []byte]bool),
|
clients: make(map[chan []byte]bool),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
level: LevelInfo,
|
|
||||||
prefix: "",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,77 +94,3 @@ func (w *LogMonitor) broadcast(msg []byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
w.prefix = prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
w.level = level
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
|
||||||
prefix := ""
|
|
||||||
if w.prefix != "" {
|
|
||||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
|
||||||
}
|
|
||||||
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
|
||||||
if level < w.level {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Write(w.formatMessage(level.String(), msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Debug(msg string) {
|
|
||||||
w.log(LevelDebug, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Info(msg string) {
|
|
||||||
w.log(LevelInfo, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Warn(msg string) {
|
|
||||||
w.log(LevelWarn, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Error(msg string) {
|
|
||||||
w.log(LevelError, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
|
||||||
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
|
||||||
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
|
||||||
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
|
||||||
w.log(LevelError, fmt.Sprintf(format, args...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l LogLevel) String() string {
|
|
||||||
switch l {
|
|
||||||
case LevelDebug:
|
|
||||||
return "DEBUG"
|
|
||||||
case LevelInfo:
|
|
||||||
return "INFO"
|
|
||||||
case LevelWarn:
|
|
||||||
return "WARN"
|
|
||||||
case LevelError:
|
|
||||||
return "ERROR"
|
|
||||||
default:
|
|
||||||
return "UNKNOWN"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+22
-35
@@ -30,12 +30,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
logMonitor *LogMonitor
|
||||||
processLogger *LogMonitor
|
|
||||||
proxyLogger *LogMonitor
|
|
||||||
|
|
||||||
healthCheckTimeout int
|
healthCheckTimeout int
|
||||||
healthCheckLoopInterval time.Duration
|
healthCheckLoopInterval time.Duration
|
||||||
@@ -55,14 +53,13 @@ type Process struct {
|
|||||||
shutdownCancel context.CancelFunc
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
processLogger: processLogger,
|
logMonitor: logMonitor,
|
||||||
proxyLogger: proxyLogger,
|
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
@@ -71,11 +68,6 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogMonitor returns the log monitor associated with the process.
|
|
||||||
func (p *Process) LogMonitor() *LogMonitor {
|
|
||||||
return p.processLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
// custom error types for swapping state
|
// custom error types for swapping state
|
||||||
var (
|
var (
|
||||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||||
@@ -93,11 +85,9 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !isValidTransition(p.state, newState) {
|
if !isValidTransition(p.state, newState) {
|
||||||
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
|
|
||||||
return p.state, ErrInvalidStateTransition
|
return p.state, ErrInvalidStateTransition
|
||||||
}
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
}
|
}
|
||||||
@@ -162,8 +152,8 @@ func (p *Process) start() error {
|
|||||||
defer p.waitStarting.Done()
|
defer p.waitStarting.Done()
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.Command(args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.processLogger
|
p.cmd.Stdout = p.logMonitor
|
||||||
p.cmd.Stderr = p.processLogger
|
p.cmd.Stderr = p.logMonitor
|
||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = p.config.Env
|
||||||
|
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
@@ -224,16 +214,15 @@ func (p *Process) start() error {
|
|||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
default:
|
default:
|
||||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
p.proxyLogger.Infof("Health check passed on %s", healthURL)
|
|
||||||
cancelHealthCheck()
|
cancelHealthCheck()
|
||||||
break loop
|
break loop
|
||||||
} else {
|
} else {
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
endTime, _ := checkDeadline.Deadline()
|
endTime, _ := checkDeadline.Deadline()
|
||||||
ttl := time.Until(endTime)
|
ttl := time.Until(endTime)
|
||||||
p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds())
|
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
|
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -257,8 +246,7 @@ func (p *Process) start() error {
|
|||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
|
||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||||
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
|
|
||||||
p.Stop()
|
p.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -279,7 +267,7 @@ func (p *Process) Stop() {
|
|||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
|
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,7 +275,7 @@ func (p *Process) Stop() {
|
|||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
|
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,32 +300,31 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
||||||
p.cmd.Process.Kill()
|
p.cmd.Process.Kill()
|
||||||
case err := <-sigtermNormal:
|
case err := <-sigtermNormal:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
if errno, ok := err.(syscall.Errno); ok {
|
||||||
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
||||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
||||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
|
||||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
+14
-14
@@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,17 +13,13 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
discardLogger = NewLogMonitorWriter(io.Discard)
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
// Create a process
|
// Create a process
|
||||||
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
|
process := NewProcess("test-process", 5, config, logMonitor)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
@@ -55,10 +52,11 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
// are all handled successfully, even though they all may ask for the process to .start()
|
// are all handled successfully, even though they all may ask for the process to .start()
|
||||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||||
|
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
|
process := NewProcess("test-process", 5, config, logMonitor)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@@ -86,7 +84,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("broken", 1, config, discardLogger, discardLogger)
|
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -111,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
config.UnloadAfter = 3 // seconds
|
config.UnloadAfter = 3 // seconds
|
||||||
assert.Equal(t, 3, config.UnloadAfter)
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
|
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// this should take 4 seconds
|
// this should take 4 seconds
|
||||||
@@ -153,7 +151,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
|||||||
config.UnloadAfter = 1 // second
|
config.UnloadAfter = 1 // second
|
||||||
assert.Equal(t, 1, config.UnloadAfter)
|
assert.Equal(t, 1, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
|
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
@@ -180,7 +178,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
|
|
||||||
expectedMessage := "12345"
|
expectedMessage := "12345"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
process := NewProcess("t", 10, config, discardLogger, discardLogger)
|
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
results := map[string]string{
|
results := map[string]string{
|
||||||
@@ -257,8 +255,9 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
|
p := &Process{
|
||||||
p.state = test.currentState
|
state: test.currentState,
|
||||||
|
}
|
||||||
|
|
||||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||||
if err != nil && test.expectedError == nil {
|
if err != nil && test.expectedError == nil {
|
||||||
@@ -283,6 +282,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
t.Skip("skipping long shutdown test")
|
t.Skip("skipping long shutdown test")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
|
|
||||||
// make a config where the healthcheck will always fail because port is wrong
|
// make a config where the healthcheck will always fail because port is wrong
|
||||||
@@ -290,7 +290,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
config.Proxy = "http://localhost:9998/test"
|
config.Proxy = "http://localhost:9998/test"
|
||||||
|
|
||||||
healthCheckTTLSeconds := 30
|
healthCheckTTLSeconds := 30
|
||||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
|
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||||
|
|
||||||
// make it a lot faster
|
// make it a lot faster
|
||||||
process.healthCheckLoopInterval = time.Second
|
process.healthCheckLoopInterval = time.Second
|
||||||
|
|||||||
+37
-77
@@ -7,7 +7,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -28,96 +27,59 @@ type ProxyManager struct {
|
|||||||
|
|
||||||
config *Config
|
config *Config
|
||||||
currentProcesses map[string]*Process
|
currentProcesses map[string]*Process
|
||||||
|
logMonitor *LogMonitor
|
||||||
ginEngine *gin.Engine
|
ginEngine *gin.Engine
|
||||||
|
|
||||||
// logging
|
|
||||||
proxyLogger *LogMonitor
|
|
||||||
upstreamLogger *LogMonitor
|
|
||||||
muxLogger *LogMonitor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *Config) *ProxyManager {
|
func New(config *Config) *ProxyManager {
|
||||||
// set up loggers
|
|
||||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
|
||||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
|
||||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
|
||||||
|
|
||||||
if config.LogRequests {
|
|
||||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
|
||||||
case "debug":
|
|
||||||
proxyLogger.SetLogLevel(LevelDebug)
|
|
||||||
case "info":
|
|
||||||
proxyLogger.SetLogLevel(LevelInfo)
|
|
||||||
case "warn":
|
|
||||||
proxyLogger.SetLogLevel(LevelWarn)
|
|
||||||
case "error":
|
|
||||||
proxyLogger.SetLogLevel(LevelError)
|
|
||||||
default:
|
|
||||||
proxyLogger.SetLogLevel(LevelInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
currentProcesses: make(map[string]*Process),
|
currentProcesses: make(map[string]*Process),
|
||||||
|
logMonitor: NewLogMonitor(),
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
|
|
||||||
proxyLogger: proxyLogger,
|
|
||||||
muxLogger: stdoutLogger,
|
|
||||||
upstreamLogger: upstreamLogger,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
if config.LogRequests {
|
||||||
// Start timer
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
start := time.Now()
|
// Start timer
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
// capture these because /upstream/:model rewrites them in c.Next()
|
// capture these because /upstream/:model rewrites them in c.Next()
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
method := c.Request.Method
|
method := c.Request.Method
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
// Process request
|
// Process request
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
// Stop timer
|
// Stop timer
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
|
||||||
statusCode := c.Writer.Status()
|
statusCode := c.Writer.Status()
|
||||||
bodySize := c.Writer.Size()
|
bodySize := c.Writer.Size()
|
||||||
|
|
||||||
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
||||||
clientIP,
|
clientIP,
|
||||||
method,
|
time.Now().Format("2006-01-02 15:04:05"),
|
||||||
path,
|
method,
|
||||||
c.Request.Proto,
|
path,
|
||||||
statusCode,
|
c.Request.Proto,
|
||||||
bodySize,
|
statusCode,
|
||||||
c.Request.UserAgent(),
|
bodySize,
|
||||||
duration,
|
c.Request.UserAgent(),
|
||||||
)
|
duration,
|
||||||
})
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// see: issue: #81, #77 and #42 for CORS issues
|
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||||
// respond with permissive OPTIONS for any endpoint
|
// respond with permissive OPTIONS for any endpoint
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == "OPTIONS" {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
// allow whatever the client requested by default
|
c.AbortWithStatus(204)
|
||||||
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
|
||||||
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
|
||||||
c.Header("Access-Control-Allow-Headers", sanitized)
|
|
||||||
} else {
|
|
||||||
c.Header(
|
|
||||||
"Access-Control-Allow-Headers",
|
|
||||||
"Content-Type, Authorization, Accept, X-Requested-With",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
c.Header("Access-Control-Max-Age", "86400")
|
|
||||||
c.AbortWithStatus(http.StatusNoContent)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
@@ -142,8 +104,6 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
|
||||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||||
@@ -303,20 +263,19 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
|
|
||||||
return process, nil
|
return process, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop all running models
|
// stop all running models
|
||||||
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
|
|
||||||
pm.stopProcesses()
|
pm.stopProcesses()
|
||||||
|
|
||||||
if profileName == "" {
|
if profileName == "" {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
} else {
|
} else {
|
||||||
@@ -327,7 +286,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
}
|
}
|
||||||
@@ -415,6 +374,7 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
|
|
||||||
accept := c.GetHeader("Accept")
|
accept := c.GetHeader("Accept")
|
||||||
if strings.Contains(accept, "text/html") {
|
if strings.Contains(accept, "text/html") {
|
||||||
// Set the Content-Type header to text/html
|
// Set the Content-Type header to text/html
|
||||||
@@ -27,7 +28,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
history := pm.muxLogger.GetHistory()
|
history := pm.logMonitor.GetHistory()
|
||||||
_, err := c.Writer.Write(history)
|
_, err := c.Writer.Write(history)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
@@ -41,14 +42,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.Header("Transfer-Encoding", "chunked")
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
ch := pm.logMonitor.Subscribe()
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
defer pm.logMonitor.Unsubscribe(ch)
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
notify := c.Request.Context().Done()
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
@@ -61,7 +56,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
// Send history first if not skipped
|
// Send history first if not skipped
|
||||||
|
|
||||||
if !skipHistory {
|
if !skipHistory {
|
||||||
history := logger.GetHistory()
|
history := pm.logMonitor.GetHistory()
|
||||||
if len(history) != 0 {
|
if len(history) != 0 {
|
||||||
c.Writer.Write(history)
|
c.Writer.Write(history)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -90,21 +85,15 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
ch := pm.logMonitor.Subscribe()
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
defer pm.logMonitor.Unsubscribe(ch)
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
notify := c.Request.Context().Done()
|
||||||
|
|
||||||
// Send history first if not skipped
|
// Send history first if not skipped
|
||||||
_, skipHistory := c.GetQuery("no-history")
|
_, skipHistory := c.GetQuery("no-history")
|
||||||
if !skipHistory {
|
if !skipHistory {
|
||||||
history := logger.GetHistory()
|
history := pm.logMonitor.GetHistory()
|
||||||
if len(history) != 0 {
|
if len(history) != 0 {
|
||||||
c.SSEvent("message", string(history))
|
c.SSEvent("message", string(history))
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
@@ -122,21 +111,3 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
|
||||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
|
||||||
var logger *LogMonitor
|
|
||||||
|
|
||||||
if logMonitorId == "" {
|
|
||||||
// maintain the default
|
|
||||||
logger = pm.muxLogger
|
|
||||||
} else if logMonitorId == "proxy" {
|
|
||||||
logger = pm.proxyLogger
|
|
||||||
} else if logMonitorId == "upstream" {
|
|
||||||
logger = pm.upstreamLogger
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
|
||||||
}
|
|
||||||
|
|
||||||
return logger, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -639,72 +639,5 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
assert.Equal(t, upstreamModelName, response["model"])
|
assert.Equal(t, upstreamModelName, response["model"])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogRequests: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
method string
|
|
||||||
requestHeaders map[string]string
|
|
||||||
expectedStatus int
|
|
||||||
expectedHeaders map[string]string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "OPTIONS with no headers",
|
|
||||||
method: "OPTIONS",
|
|
||||||
expectedStatus: http.StatusNoContent,
|
|
||||||
expectedHeaders: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "OPTIONS with specific headers",
|
|
||||||
method: "OPTIONS",
|
|
||||||
requestHeaders: map[string]string{
|
|
||||||
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
|
||||||
},
|
|
||||||
expectedStatus: http.StatusNoContent,
|
|
||||||
expectedHeaders: map[string]string{
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Non-OPTIONS request",
|
|
||||||
method: "GET",
|
|
||||||
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
|
||||||
for k, v := range tt.requestHeaders {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.ginEngine.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
|
||||||
|
|
||||||
for header, expectedValue := range tt.expectedHeaders {
|
|
||||||
assert.Equal(t, expectedValue, w.Header().Get(header))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isTokenChar(r rune) bool {
|
|
||||||
switch {
|
|
||||||
case r >= 'a' && r <= 'z':
|
|
||||||
case r >= 'A' && r <= 'Z':
|
|
||||||
case r >= '0' && r <= '9':
|
|
||||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
|
||||||
parts := strings.Split(headerValues, ",")
|
|
||||||
valid := make([]string, 0, len(parts))
|
|
||||||
|
|
||||||
for _, p := range parts {
|
|
||||||
v := strings.TrimSpace(p)
|
|
||||||
if v == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
validPart := true
|
|
||||||
for _, c := range v {
|
|
||||||
if !isTokenChar(c) {
|
|
||||||
validPart = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if validPart {
|
|
||||||
valid = append(valid, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(valid, ", ")
|
|
||||||
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
input: "",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "whitespace only",
|
|
||||||
input: " ",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single valid value",
|
|
||||||
input: "content-type",
|
|
||||||
expected: "content-type",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple valid values",
|
|
||||||
input: "content-type, authorization, x-requested-with",
|
|
||||||
expected: "content-type, authorization, x-requested-with",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "values with extra spaces",
|
|
||||||
input: " content-type , authorization ",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "values with tabs",
|
|
||||||
input: "content-type,\tauthorization",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "values with invalid characters",
|
|
||||||
input: "content-type, auth\n, x-requested-with\r",
|
|
||||||
expected: "content-type, auth, x-requested-with",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty values in list",
|
|
||||||
input: "content-type,,authorization",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "leading and trailing commas",
|
|
||||||
input: ",content-type,authorization,",
|
|
||||||
expected: "content-type, authorization",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed valid and invalid values",
|
|
||||||
input: "content-type, \x00invalid, x-requested-with",
|
|
||||||
expected: "content-type, x-requested-with",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed case values",
|
|
||||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
|
||||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
|
||||||
if got != tt.expected {
|
|
||||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
|
||||||
tt.input, got, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user