Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5944a86e86 | |||
| 63d4a7d0eb | |||
| f45469f7ff | |||
| 34f9fd7340 | |||
| 8448efa7fc | |||
| 8cf2a389d8 | |||
| 0f133f5b74 | |||
| 1510b3fbd9 | |||
| 0f8a8e70f1 | |||
| 6c3819022c | |||
| 8580f0f733 | |||
| be82d1a6a0 | |||
| 6cf0962807 | |||
| 8eb5b7b6c4 | |||
| 5a57688aa8 | |||
| b79b7ef3d9 | |||
| 476086c066 | |||
| 4fae7cf946 |
+3
-1
@@ -1,3 +1,5 @@
|
||||
.aider*
|
||||
.env
|
||||
build/
|
||||
build/
|
||||
dist/
|
||||
.vscode
|
||||
@@ -0,0 +1,11 @@
|
||||
version: 2
|
||||
|
||||
builds:
|
||||
- env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
@@ -9,6 +9,9 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -v ./proxy
|
||||
|
||||
# Build OSX binary
|
||||
mac:
|
||||
@echo "Building Mac binary..."
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# llama-swap
|
||||
|
||||
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, so let's swap llama-server instead!
|
||||

|
||||
|
||||
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models on demand. So let's swap the server on demand instead!
|
||||
|
||||
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
|
||||
|
||||
- ✅ easy to deploy: single binary with no dependencies
|
||||
- ✅ full control over llama-server's startup settings
|
||||
- ✅ ❤️ for nvidia P40 users who are rely on llama.cpp for inference
|
||||
- ✅ ❤️ for users who are rely on llama.cpp for LLM inference
|
||||
|
||||
## config.yaml
|
||||
|
||||
@@ -20,34 +22,66 @@ healthCheckTimeout: 60
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: "llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf"
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# Where to proxy to, important it matches this format
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
# where to reach the server started by cmd
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases model names to use this configuration for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# wait for this path to return an HTTP 200 before serving requests
|
||||
# defaults to /health to match llama.cpp
|
||||
#
|
||||
# use "none" to skip endpoint checking. This may cause requests to fail
|
||||
# until the server is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
cmd: "llama-server --port 8999 -m path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
|
||||
proxy: http://127.0.0.1:8999
|
||||
```
|
||||
|
||||
## Deployment
|
||||
## Installation
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
The `/logs` endpoint is available to monitor what llama-swap is doing. It will send the last 10KB of logs. Useful for monitoring the output of llama-server. It also supports streaming of logs.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
# basic, sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# add `stream` to stream new logs as they come in
|
||||
curl -Ns 'http://host/logs?stream'
|
||||
|
||||
# add `skip` to skip history (only useful if used with stream)
|
||||
curl -Ns 'http://host/logs?stream&skip'
|
||||
|
||||
# will output nothing :)
|
||||
curl -Ns 'http://host/logs?skip'
|
||||
```
|
||||
|
||||
## Systemd Unit Files
|
||||
|
||||
Use this unit file to start llama-swap on boot
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
```
|
||||
@@ -68,4 +102,10 @@ StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
1. Install golang for your system
|
||||
1. run `make clean all`
|
||||
1. binaries will be built into `build/` directory
|
||||
|
||||
+28
-15
@@ -1,31 +1,44 @@
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
healthCheckTimeout: 15
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: "models/llama-server-osx --port 8999 -m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: >
|
||||
models/llama-server-osx
|
||||
--port 8999
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- gpt-4o-mini
|
||||
|
||||
# check this path for a HTTP 200 response for the server to be ready
|
||||
checkEndpoint: /health
|
||||
|
||||
"qwen":
|
||||
cmd: "models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
aliases:
|
||||
- "gpt-3.5-turbo"
|
||||
- gpt-3.5-turbo
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0,1"
|
||||
- "env1=hello"
|
||||
cmd: "build/simple-responder --port 8999"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
- CUDA_VISIBLE_DEVICES=0,1
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# don't use this, just for testing if things are broken
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
# until the upstream server is ready for traffic
|
||||
checkEndpoint: none
|
||||
|
||||
# don't use these, just for testing if things are broken
|
||||
"broken":
|
||||
cmd: "models/llama-server-osx --port 8999 -m models/doesnotexist.gguf"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
@@ -2,4 +2,12 @@ module github.com/mostlygeek/llama-swap
|
||||
|
||||
go 1.23.0
|
||||
|
||||
require gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
require (
|
||||
github.com/stretchr/testify v1.9.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
BIN
Binary file not shown.
|
After Width: | Height: | Size: 261 KiB |
+1
-1
@@ -25,7 +25,7 @@ func main() {
|
||||
proxyManager := proxy.New(config)
|
||||
http.HandleFunc("/", proxyManager.HandleFunc)
|
||||
|
||||
fmt.Println("llamagate listening on " + *listenStr)
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := http.ListenAndServe(*listenStr, nil); err != nil {
|
||||
fmt.Printf("Error starting server: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
+27
-4
@@ -1,16 +1,23 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
Cmd string `yaml:"cmd"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -54,3 +61,19 @@ func LoadConfig(path string) (*Config, error) {
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
// Remove trailing backslashes
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
|
||||
// Split the command into arguments
|
||||
args := strings.Fields(cmdStr)
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
healthCheckTimeout: 15
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
}
|
||||
|
||||
func TestModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestFindConfig(t *testing.T) {
|
||||
config := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "python model1.py",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "python model2.py",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2", "model-two"},
|
||||
Env: []string{"VAR3=value3", "VAR4=value4"},
|
||||
CheckEndpoint: "/status",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 10,
|
||||
}
|
||||
|
||||
// Test finding a model by its name
|
||||
modelConfig, found := config.FindConfig("model1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model by its alias
|
||||
modelConfig, found = config.FindConfig("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model that does not exist
|
||||
modelConfig, found = config.FindConfig("model3")
|
||||
assert.False(t, found)
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestSanitizeCommand(t *testing.T) {
|
||||
// Test a simple command
|
||||
args, err := SanitizeCommand("python model1.py")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py"}, args)
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err = SanitizeCommand(`python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
bufferMu sync.RWMutex
|
||||
|
||||
// typically this can be os.Stdout
|
||||
stdout io.Writer
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
return NewLogMonitorWriter(os.Stdout)
|
||||
}
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, err = w.stdout.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
w.broadcast(p)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *LogMonitor) GetHistory() []byte {
|
||||
w.bufferMu.RLock()
|
||||
defer w.bufferMu.RUnlock()
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
}
|
||||
}
|
||||
})
|
||||
return history
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Subscribe() chan []byte {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
ch := make(chan []byte, 100)
|
||||
w.clients[ch] = true
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
delete(w.clients, ch)
|
||||
close(ch)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) broadcast(msg []byte) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
for client := range w.clients {
|
||||
select {
|
||||
case client <- msg:
|
||||
default:
|
||||
// If client buffer is full, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Test subscription
|
||||
client1 := logMonitor.Subscribe()
|
||||
client2 := logMonitor.Subscribe()
|
||||
|
||||
defer logMonitor.Unsubscribe(client1)
|
||||
defer logMonitor.Unsubscribe(client2)
|
||||
|
||||
client1Messages := make([]byte, 0)
|
||||
client2Messages := make([]byte, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case data := <-client1:
|
||||
client1Messages = append(client1Messages, data...)
|
||||
case data := <-client2:
|
||||
client2Messages = append(client2Messages, data...)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
logMonitor.Write([]byte("1"))
|
||||
logMonitor.Write([]byte("2"))
|
||||
logMonitor.Write([]byte("3"))
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
wg.Wait()
|
||||
|
||||
// Check the buffer
|
||||
expectedHistory := "123"
|
||||
history := string(logMonitor.GetHistory())
|
||||
|
||||
if history != expectedHistory {
|
||||
t.Errorf("Expected history: %s, got: %s", expectedHistory, history)
|
||||
}
|
||||
|
||||
c1Data := string(client1Messages)
|
||||
if c1Data != expectedHistory {
|
||||
t.Errorf("Client1 expected %s, got: %s", expectedHistory, c1Data)
|
||||
}
|
||||
|
||||
c2Data := string(client2Messages)
|
||||
if c2Data != expectedHistory {
|
||||
t.Errorf("Client2 expected %s, got: %s", expectedHistory, c2Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||
// Create a new LogMonitor instance
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Prepare a message to write
|
||||
msg := []byte("Hello, World!")
|
||||
lenmsg := len(msg)
|
||||
|
||||
// Write the message to the LogMonitor
|
||||
n, err := lm.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
if n != lenmsg {
|
||||
t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
|
||||
}
|
||||
|
||||
// Change the original message
|
||||
msg[0] = 'B' // This should not affect the buffer
|
||||
|
||||
// Get the history from the LogMonitor
|
||||
history := lm.GetHistory()
|
||||
|
||||
// Check that the history contains the original message, not the modified one
|
||||
expected := []byte("Hello, World!")
|
||||
if !bytes.Equal(history, expected) {
|
||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||
}
|
||||
}
|
||||
+145
-25
@@ -4,10 +4,11 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -21,24 +22,90 @@ type ProxyManager struct {
|
||||
config *Config
|
||||
currentCmd *exec.Cmd
|
||||
currentConfig ModelConfig
|
||||
logMonitor *LogMonitor
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
return &ProxyManager{config: config}
|
||||
return &ProxyManager{config: config, logMonitor: NewLogMonitor()}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints
|
||||
|
||||
if r.URL.Path == "/v1/chat/completions" {
|
||||
// extracts the `model` from json body
|
||||
pm.proxyChatRequest(w, r)
|
||||
} else if r.URL.Path == "/v1/models" {
|
||||
pm.listModels(w, r)
|
||||
} else if r.URL.Path == "/logs" {
|
||||
pm.streamLogs(w, r)
|
||||
} else {
|
||||
pm.proxyRequest(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogs(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := r.Context().Done()
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
skipHistory := r.URL.Query().Has("skip")
|
||||
if !skipHistory {
|
||||
// Send history first
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
w.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if !r.URL.Query().Has("stream") {
|
||||
return
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
w.Write(msg)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModels(w http.ResponseWriter, _ *http.Request) {
|
||||
data := []interface{}{}
|
||||
for id := range pm.config.Models {
|
||||
data = append(data, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
http.Error(w, "Error encoding JSON", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
@@ -57,40 +124,77 @@ func (pm *ProxyManager) swapModel(requestedModel string) error {
|
||||
// kill the current running one to swap it
|
||||
if pm.currentCmd != nil {
|
||||
pm.currentCmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
// wait for it to end
|
||||
pm.currentCmd.Process.Wait()
|
||||
}
|
||||
|
||||
pm.currentConfig = modelConfig
|
||||
|
||||
args := strings.Fields(modelConfig.Cmd)
|
||||
args, err := modelConfig.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// logMonitor only writes to stdout
|
||||
// so the upstream's stderr will go to os.Stdout
|
||||
cmd.Stdout = pm.logMonitor
|
||||
cmd.Stderr = pm.logMonitor
|
||||
|
||||
cmd.Env = modelConfig.Env
|
||||
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pm.currentCmd = cmd
|
||||
|
||||
if err := pm.checkHealthEndpoint(); err != nil {
|
||||
// watch for the command to exist
|
||||
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
||||
|
||||
// monitor the command's exist status
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("command [%s] %s", strings.Join(cmd.Args, " "), err.Error()))
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for checkHealthEndpoint
|
||||
if err := pm.checkHealthEndpoint(cmdCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) checkHealthEndpoint() error {
|
||||
func (pm *ProxyManager) checkHealthEndpoint(cmdCtx context.Context) error {
|
||||
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
|
||||
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
healthURL := proxyTo + "/health"
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -99,33 +203,50 @@ func (pm *ProxyManager) checkHealthEndpoint() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(req.Context(), 250*time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithTimeout(cmdCtx, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
|
||||
// if TCP dial can't connect any HTTP response after 5 seconds
|
||||
// exit quickly.
|
||||
if time.Since(startTime) > 5*time.Second {
|
||||
return fmt.Errorf("/healthy endpoint took more than 5 seconds to respond")
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if time.Since(startTime) >= maxDuration {
|
||||
return fmt.Errorf("failed to check /healthy from: %s", healthURL)
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(pm.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
if time.Since(startTime) >= maxDuration {
|
||||
return fmt.Errorf("failed to check /healthy from: %s", healthURL)
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
@@ -148,7 +269,7 @@ func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
if err := pm.swapModel(model); err != nil {
|
||||
http.Error(w, fmt.Sprintf("unable to swap to model: %s", err.Error()), http.StatusNotFound)
|
||||
http.Error(w, fmt.Sprintf("unable to swap to model, %s", err.Error()), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -190,7 +311,6 @@ func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
http.Error(w, writeErr.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
|
||||
Reference in New Issue
Block a user