Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 | |||
| 3a1e9f81f1 | |||
| 72c883f36c | |||
| 1b04d034cf | |||
| 2e45f5692a | |||
| c97b80bdfe | |||
| ae3ef9bc39 | |||
| db6715bec3 | |||
| da5d9e8a6a |
@@ -3,32 +3,40 @@
|
||||

|
||||
|
||||
# Introduction
|
||||
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Features:
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
|
||||
|
||||
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
|
||||
|
||||
## How does it work?
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API support (`v1/completions`, `v1/chat/completions`, `v1/embeddings` and `v1/rerank`)
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Docker Support ([#40](https://github.com/mostlygeek/llama-swap/pull/40))
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabblyAPI, etc)
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
|
||||
## Releases
|
||||
|
||||
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
-
|
||||
|
||||
## config.yaml
|
||||
|
||||
@@ -83,6 +91,20 @@ models:
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
|
||||
# Docker Support (Experimental)
|
||||
# see: https://github.com/mostlygeek/llama-swap/pull/40
|
||||
"dockertest":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
|
||||
# introduced to reliably stop containers
|
||||
cmd_stop: docker stop -t 2 dockertest
|
||||
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
@@ -95,19 +117,26 @@ profiles:
|
||||
- "llama"
|
||||
```
|
||||
|
||||
**Advanced examples**
|
||||
### Advanced Examples
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
|
||||
## Installation
|
||||
### Installation
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
|
||||
@@ -53,6 +53,21 @@ models:
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
# EXPERIMENTAL! Docker Support
|
||||
# see:
|
||||
# - https://github.com/mostlygeek/llama-swap/pull/40
|
||||
# - https://github.com/mostlygeek/llama-swap/issues/35
|
||||
"dockertest":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
|
||||
# use this to reliably stop named containers
|
||||
cmd_stop: docker stop -t 2 dockertest
|
||||
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
|
||||
@@ -33,7 +33,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
|
||||
@@ -70,6 +70,8 @@ golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
@@ -39,6 +41,16 @@ func main() {
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("Shutting down llama-swap")
|
||||
proxyManager.StopProcesses()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmd_stop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
@@ -22,6 +23,9 @@ type ModelConfig struct {
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
func (m *ModelConfig) SanitizeCommandStop() ([]string, error) {
|
||||
return SanitizeCommand(m.CmdStop)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
|
||||
@@ -35,6 +35,11 @@ models:
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
docker:
|
||||
cmd: docker run -p 9999:8080 --name "my_container"
|
||||
cmd_stop: docker stop my_container
|
||||
proxy: "http://localhost:9999"
|
||||
checkEndpoint: "/health"
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
@@ -56,6 +61,7 @@ profiles:
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
@@ -63,11 +69,19 @@ profiles:
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"docker": {
|
||||
Cmd: `docker run -p 9999:8080 --name "my_container"`,
|
||||
CmdStop: "docker stop my_container",
|
||||
Proxy: "http://localhost:9999",
|
||||
Env: nil,
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
@@ -99,6 +113,18 @@ func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommandStop(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
CmdStop: `docker stop my_container \
|
||||
--arg1 1
|
||||
--arg2 2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizeCommandStop()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"docker", "stop", "my_container", "--arg1", "1", "--arg2", "2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>llama-swap</h1>
|
||||
<p>
|
||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
+113
-21
@@ -12,42 +12,134 @@
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-controls {
|
||||
margin: 0.5em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between; /* Spaces out elements evenly */
|
||||
}
|
||||
#log-controls input {
|
||||
flex: 1;
|
||||
}
|
||||
#log-controls input:focus {
|
||||
outline: none; /* Ensures no outline is shown when the input is focused */
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 1em;
|
||||
padding: 10px;
|
||||
margin: 0.5em;
|
||||
padding: 1em;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
|
||||
.regex-error {
|
||||
background-color: #ff0000 !important;
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body {
|
||||
background-color: #333;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-stream {
|
||||
background: #444;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-controls input {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
|
||||
#log-controls button {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...
|
||||
</pre>
|
||||
|
||||
<pre id="log-stream">Waiting for logs...</pre>
|
||||
<div id="log-controls">
|
||||
<input type="text" id="filter-input" placeholder="regex filter">
|
||||
<button id="clear-button">clear</button>
|
||||
</div>
|
||||
<script>
|
||||
// Establish an EventSource connection to the SSE endpoint
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
const logStream = document.getElementById('log-stream');
|
||||
const filterInput = document.getElementById('filter-input');
|
||||
var logData = "";
|
||||
let regexFilter = null;
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
// Append the new log message to the <pre> element
|
||||
const logStream = document.getElementById('log-stream');
|
||||
function setupEventSource() {
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
logStream.textContent += event.data;
|
||||
eventSource.onmessage = function(event) {
|
||||
logData += event.data;
|
||||
render()
|
||||
};
|
||||
|
||||
// Auto-scroll to the bottom
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
console.error("EventSource failed:", err);
|
||||
};
|
||||
} else {
|
||||
console.error("SSE not supported by this browser.");
|
||||
eventSource.onerror = function(err) {
|
||||
logData = "EventSource failed: " + err.message;
|
||||
};
|
||||
} else {
|
||||
logData = "SSE Not supported by this browser."
|
||||
}
|
||||
}
|
||||
|
||||
// poor-ai's react ¯\_(ツ)_/¯
|
||||
function render() {
|
||||
if (regexFilter) {
|
||||
const lines = logData.split('\n');
|
||||
const filteredLines = lines.filter(line => {
|
||||
return regexFilter === null || regexFilter.test(line);
|
||||
});
|
||||
|
||||
if (filteredLines.length > 0) {
|
||||
logStream.textContent = filteredLines.join('\n') + '\n';
|
||||
} else {
|
||||
logStream.textContent = "";
|
||||
}
|
||||
} else {
|
||||
logStream.textContent = logData;
|
||||
}
|
||||
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
}
|
||||
|
||||
function updateFilter() {
|
||||
const pattern = filterInput.value.trim();
|
||||
filterInput.classList.remove('regex-error');
|
||||
if (pattern) {
|
||||
try {
|
||||
regexFilter = new RegExp(pattern);
|
||||
} catch (e) {
|
||||
console.error("Invalid regex pattern:", e);
|
||||
regexFilter = null;
|
||||
filterInput.classList.add('regex-error');
|
||||
return
|
||||
}
|
||||
} else {
|
||||
regexFilter = null;
|
||||
}
|
||||
|
||||
render();
|
||||
}
|
||||
|
||||
filterInput.addEventListener('input', updateFilter);
|
||||
document.getElementById('clear-button').addEventListener('click', () => {
|
||||
filterInput.value = "";
|
||||
regexFilter = null;
|
||||
render();
|
||||
});
|
||||
setupEventSource();
|
||||
updateFilter();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,10 @@
|
||||
package proxy
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed html
|
||||
var htmlFiles embed.FS
|
||||
|
||||
func getHTMLFile(path string) ([]byte, error) {
|
||||
return htmlFiles.ReadFile("html/" + path)
|
||||
}
|
||||
+52
-20
@@ -135,6 +135,7 @@ func (p *Process) start() error {
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -152,12 +153,13 @@ func (p *Process) Stop() {
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != StateReady {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Stop() called but Process State is not READY\n")
|
||||
return
|
||||
}
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
// this situation should never happen... but if it does just update the state
|
||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n")
|
||||
p.state = StateStopped
|
||||
return
|
||||
}
|
||||
@@ -165,29 +167,59 @@ func (p *Process) Stop() {
|
||||
// Pretty sure this stopping code needs some work for windows and
|
||||
// will be a source of pain in the future.
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-sigtermNormal:
|
||||
if p.config.CmdStop != "" {
|
||||
// for issue #35 to do things like `docker stop`
|
||||
args, err := p.config.SanitizeCommandStop()
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// possible that simple-responder for testing is just not
|
||||
// existing right, so suppress those errors.
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
fmt.Fprintf(p.logMonitor, "!!! Error sanitizing stop command: %v\n", err)
|
||||
|
||||
// leave the state as it is?
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(p.logMonitor, "!!! Running stop command: %s\n", strings.Join(args, " "))
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdout = p.logMonitor
|
||||
cmd.Stderr = p.logMonitor
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Error running stop command: %v\n", err)
|
||||
|
||||
// leave the state as it is?
|
||||
return
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! WARNING error waiting for stop command to complete: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "XXX Process for %s timed out waiting to stop, sending SIGKILL to PID: %d\n", p.ID, p.cmd.Process.Pid)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// possible that simple-responder for testing is just not
|
||||
// existing right, so suppress those errors.
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.state = StateStopped
|
||||
}
|
||||
|
||||
|
||||
+28
-2
@@ -67,7 +67,6 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
}
|
||||
|
||||
// test that the process unloads after the TTL
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
@@ -79,7 +78,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -111,6 +110,33 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
|
||||
+45
-15
@@ -2,7 +2,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -20,15 +19,6 @@ const (
|
||||
PROFILE_SPLIT_CHAR = ":"
|
||||
)
|
||||
|
||||
//go:embed html/favicon.ico
|
||||
var faviconData []byte
|
||||
|
||||
//go:embed html/logs.html
|
||||
var logsHTML []byte
|
||||
|
||||
// make sure embed is kept there by the IDE auto-package importer
|
||||
var _ = embed.FS{}
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
@@ -51,16 +41,17 @@ func New(config *Config) *ProxyManager {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log request details
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
@@ -87,6 +78,9 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
@@ -97,8 +91,29 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
htmlData, err := getHTMLFile("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(htmlData)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
c.Data(http.StatusOK, "image/x-icon", faviconData)
|
||||
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
||||
c.Data(http.StatusOK, "image/x-icon", data)
|
||||
} else {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
// Disable console color for testing
|
||||
@@ -187,6 +202,21 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
// check if model is part of the profile
|
||||
if profileName != "" {
|
||||
found := false
|
||||
for _, item := range pm.config.Profiles[profileName] {
|
||||
if item == realModelName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// exit early when already running, otherwise stop everything and swap
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
|
||||
@@ -16,9 +16,14 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
_, err := c.Writer.Write(logsHTML)
|
||||
logsHTML, err := getHTMLFile("logs.html")
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(logsHTML)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@@ -43,7 +48,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("Streaming unsupported"))
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -53,11 +58,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -68,7 +69,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
case msg := <-ch:
|
||||
_, err := c.Writer.Write(msg)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
// just break the loop if we can't write for some reason
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
@@ -210,3 +210,47 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileMemberName := ProcessKeyName("test", model1)
|
||||
profileNonMemberName := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1},
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
// actual member of profile
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
}
|
||||
|
||||
// actual model, but non-member will 404
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user