Compare commits

...

29 Commits

Author SHA1 Message Date
Benson Wong 85cd74a51c Improve process start and stop reliability (#38)
Refactor Process.start()/Stop() logic (#38)
- remove cmd.Wait() call in start(). This seems to conflict with the one
  in .Stop(). Removing it eliminated no child errors
- eliminate goroutines in .start() as it no longer required
2025-02-03 11:50:38 -08:00
Benson Wong 314d2f2212 remove cmd_stop configuration and functionality from PR #40 (#44)
* remove cmd_stop functionality from #40
2025-01-31 12:42:44 -08:00
Benson Wong fad25f3e11 Use client request context in proxy request (#43)
Canceled or closed HTTP requests from clients will also stop the proxied
HTTP requests to upstreamed servers.
2025-01-31 10:21:49 -08:00
Benson Wong 2c3e3e27f7 Support OPTIONS requests (#42)
Add middleware that responds with permissive OPTIONS headers
for all request paths.
2025-01-31 10:09:07 -08:00
Benson Wong baeb0c4e7f Add cmd_stop configuration to better support docker (#35)
Add `cmd_stop` to model configuration to run a command instead of sending a SIGTERM to shutdown a process before swapping.
2025-01-30 16:59:57 -08:00
Benson Wong 2833517eef Improve handling of process that do not handle SIGTERM (#38)
- Process TTL goroutine did not have a return after .Stop()
- Improve logging
- Add test TestProcess_LowTTLValue to measure SIGTERM error rate
2025-01-20 14:39:52 -08:00
Benson Wong abdc2bfdb3 Fix panic when requesting non-members of profiles
A panic occurs when a request for an invalid profile:model pair is made.
The edge case is that the profile exists and the model exists but they're
not configured as a pair.

This adds an additional check to make sure the profile:model pair is
valid before attempting to swap the model.
2025-01-16 12:06:38 -08:00
Benson Wong c3b834737f Update README.md 2025-01-13 22:37:30 -08:00
Benson Wong 3c8e727b73 Update README.md 2025-01-12 19:48:35 -08:00
Benson Wong 3a1e9f81f1 support TTS /v1/audio/speech (#36) 2025-01-12 16:27:01 -08:00
Benson Wong 72c883f36c Update README.md 2025-01-02 09:01:51 -08:00
Benson Wong 1b04d034cf Update README.md 2025-01-02 08:59:11 -08:00
Benson Wong 2e45f5692a Update README.md
Improve README documentation.
2025-01-01 12:51:24 -08:00
Benson Wong c97b80bdfe Update README.md 2025-01-01 12:25:45 -08:00
Benson Wong ae3ef9bc39 Refactor UI (#33)
- add html to / instead of 404
- add client side regex to /logs
2024-12-23 19:48:59 -08:00
Benson Wong db6715bec3 update golang.org/x/net -> v0.33.0 for dependabot 2024-12-20 11:28:32 -08:00
Benson Wong da5d9e8a6a fix HTTP logging so true path is printed 2024-12-20 11:25:01 -08:00
Benson Wong 84b667ca7a improve logging and error reporting for troubleshooting 2024-12-20 10:46:56 -08:00
Benson Wong 29657106fc add more OpenAI API supported in README 2024-12-20 10:08:20 -08:00
Benson Wong 9c8860471e support v1/rerank endpoint 2024-12-17 21:22:25 -08:00
Benson Wong 9b4e3f307e rename proxy handler 2024-12-17 17:25:10 -08:00
Benson Wong 6fe37c3abf support /v1/embeddings (#4) 2024-12-17 17:25:10 -08:00
Benson Wong 7f45493a37 Update README.md 2024-12-17 14:45:41 -08:00
Benson Wong 891f6a5b5a Add /upstream endpoint (#30)
* remove catch-all route to upstream proxy (it was broken anyways)
* add /upstream/:model_id to swap and route to upstream path
* add /upstream HTML endpoint and unlisted option
* add /upstream endpoint to show a list of available models
* add `unlisted` configuration option to omit a model from /v1/models and /upstream lists
* add favicon.ico
2024-12-17 14:37:44 -08:00
Benson Wong 7183f6b43d fix bad logging due to wrong []byte used #28 2024-12-16 16:22:14 -08:00
Benson Wong d89bfeb441 add .DS_Store to .gitignore 2024-12-16 12:30:31 -08:00
Benson Wong 9a0c6bed40 Improve stop exceptions (#28) (#29)
Stop Process TTL goroutine when process is not ready (#28)

- fix issue where the goroutine will continue even though the child
  process is no longer running and the Process' state is not Ready
- fix issue where some logs were going to stdout instead of p.logMonitor
  causing them to not show up in the /logs
- add units to unloading model message
2024-12-16 12:29:25 -08:00
Benson Wong d6ca535939 tweak release tagging so it is not based on number of commits 2024-12-14 15:46:10 -08:00
Benson Wong 27302c0c02 change llama-swap to use goreleaser default ldflag values 2024-12-14 10:30:06 -08:00
21 changed files with 564 additions and 152 deletions
+2 -1
View File
@@ -2,4 +2,5 @@
.env .env
build/ build/
dist/ dist/
.vscode .vscode
.DS_Store
+10 -7
View File
@@ -9,8 +9,8 @@ ifneq ($(shell git status --porcelain),)
GIT_HASH := $(GIT_HASH)+ GIT_HASH := $(GIT_HASH)+
endif endif
# Get the build number from the commit count on the main branch # Capture the current build date in RFC3339 format
COMMIT_COUNT := $(shell git rev-list --count HEAD) BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
# Default target: Builds binaries for both OSX and Linux # Default target: Builds binaries for both OSX and Linux
all: mac linux simple-responder all: mac linux simple-responder
@@ -28,12 +28,12 @@ test-all:
# Build OSX binary # Build OSX binary
mac: mac:
@echo "Building Mac binary..." @echo "Building Mac binary..."
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.GIT_HASH=${GIT_HASH} -X main.COMMIT_COUNT=${COMMIT_COUNT}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64 GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary # Build Linux binary
linux: linux:
@echo "Building Linux binary..." @echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.GIT_HASH=${GIT_HASH} -X main.COMMIT_COUNT=${COMMIT_COUNT}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64 GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
# for testing proxy.Process # for testing proxy.Process
simple-responder: simple-responder:
@@ -52,9 +52,12 @@ release:
echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \ echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \
exit 1; \ exit 1; \
fi fi
@echo "Creating release tag v$(COMMIT_COUNT)..."
git tag v$(COMMIT_COUNT) # Get the highest tag in v{number} format, increment it, and create a new tag
git push origin v$(COMMIT_COUNT) @highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \
new_tag="v$$(( $${highest_tag#v} + 1 ))"; \
echo "tagging new version: $$new_tag"; \
git tag "$$new_tag";
# Phony targets # Phony targets
.PHONY: all clean osx linux .PHONY: all clean osx linux
+55 -18
View File
@@ -3,31 +3,39 @@
![llama-swap header image](header.jpeg) ![llama-swap header image](header.jpeg)
# Introduction # Introduction
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead! llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Features: Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
## How does it work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
## Features:
- ✅ Easy to deploy: single binary with no dependencies - ✅ Easy to deploy: single binary with no dependencies
-Single yaml configuration file -Easy to config: single yaml file
- ✅ On-demand model switching - ✅ On-demand model switching
- ✅ Full control over server settings per model - ✅ Full control over server settings per model
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`) - ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- ✅ Multiple GPU support - ✅ Multiple GPU support
- ✅ Docker and Podman support
- ✅ Run multiple models at once with `profiles` - ✅ Run multiple models at once with `profiles`
- ✅ Remote log monitoring at `/log` - ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout - ✅ Automatic unloading of models from GPUs after timeout
- ✅ Use any local server that provides an OpenAI compatible API (llama.cpp, vllm, tabblyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
## Releases
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
### Building from source
1. Install golang for your system
1. `git clone git@github.com:mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## config.yaml ## config.yaml
@@ -38,6 +46,9 @@ llama-swap's configuration is purposefully simple.
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false
logRequests: true
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
"llama": "llama":
@@ -73,6 +84,21 @@ models:
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf --model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
unlisted: true
# Docker Support (v26.1.4+ required!)
"docker-llama":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles make it easy to managing multi model (and gpu) configurations. # profiles make it easy to managing multi model (and gpu) configurations.
# #
# Tips: # Tips:
@@ -85,15 +111,26 @@ profiles:
- "llama" - "llama"
``` ```
More [examples](examples/README.md) are available for different use cases. ### Advanced Examples
## Installation - [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
### Installation
1. Create a configuration file, see [config.example.yaml](config.example.yaml) 1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
* _Note: Windows currently untested._ * _Note: Windows currently untested._
1. Run the binary with `llama-swap --config path/to/config.yaml` 1. Run the binary with `llama-swap --config path/to/config.yaml`
### Building from source
1. Install golang for your system
1. `git clone git@github.com:mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## Monitoring Logs ## Monitoring Logs
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs. Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
+39
View File
@@ -2,6 +2,9 @@
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 15 healthCheckTimeout: 15
# Log HTTP requests helpful for troubleshoot, defaults to False
logRequests: true
models: models:
"llama": "llama":
cmd: > cmd: >
@@ -26,6 +29,39 @@ models:
aliases: aliases:
- gpt-3.5-turbo - gpt-3.5-turbo
# Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic":
proxy: http://127.0.0.1:9005
cmd: >
models/llama-server-osx --port 9005
-m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192
--batch-size 8192
--rope-scaling yarn
--rope-freq-scale 0.75
-ngl 99
--embeddings
# Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker":
proxy: http://127.0.0.1:9006
cmd: >
models/llama-server-osx --port 9006
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192
--reranking
# Docker Support (v26.1.4+ required!)
"dockertest":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
"simple": "simple":
# example of setting environment variables # example of setting environment variables
env: env:
@@ -33,6 +69,7 @@ models:
- env1=hello - env1=hello
cmd: build/simple-responder --port 8999 cmd: build/simple-responder --port 8999
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail # use "none" to skip check. Caution this may cause some requests to fail
# until the upstream server is ready for traffic # until the upstream server is ready for traffic
@@ -42,9 +79,11 @@ models:
"broken": "broken":
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
unlisted: true
"broken_timeout": "broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000 proxy: http://127.0.0.1:9000
unlisted: true
# creating a coding profile with models for code generation and general questions # creating a coding profile with models for code generation and general questions
profiles: profiles:
+1 -1
View File
@@ -33,7 +33,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0 // indirect golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.25.0 // indirect golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
+2
View File
@@ -70,6 +70,8 @@ golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
+16 -4
View File
@@ -4,14 +4,16 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"os/signal"
"syscall"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
) )
// see Makefile which injects new values at build time var version string = "0"
var GIT_HASH string = "abcd1234" var commit string = "abcd1234"
var COMMIT_COUNT string = "0-dev" var date = "unknown"
func main() { func main() {
// Define a command-line flag for the port // Define a command-line flag for the port
@@ -22,7 +24,7 @@ func main() {
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
if *showVersion { if *showVersion {
fmt.Printf("version: v%s (%s)\n", COMMIT_COUNT, GIT_HASH) fmt.Printf("version: %s (%s), built at %s\n", version, commit, date)
os.Exit(0) os.Exit(0)
} }
@@ -39,6 +41,16 @@ func main() {
} }
proxyManager := proxy.New(config) proxyManager := proxy.New(config)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
fmt.Println("Shutting down llama-swap")
proxyManager.StopProcesses()
os.Exit(0)
}()
fmt.Println("llama-swap listening on " + *listenStr) fmt.Println("llama-swap listening on " + *listenStr)
if err := proxyManager.Run(*listenStr); err != nil { if err := proxyManager.Run(*listenStr); err != nil {
fmt.Printf("Server error: %v\n", err) fmt.Printf("Server error: %v\n", err)
Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

+4
View File
@@ -0,0 +1,4 @@
The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510
To run it:
> curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq .
+17
View File
@@ -0,0 +1,17 @@
{
"model": "bge-reranker",
"query": "Organic skincare products for sensitive skin",
"top_n": 3,
"documents": [
"Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.",
"New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.",
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.",
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.",
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.",
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.",
"针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。",
"新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。",
"敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。",
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。"
]
}
+2
View File
@@ -16,6 +16,7 @@ type ModelConfig struct {
Env []string `yaml:"env"` Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"` CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
@@ -24,6 +25,7 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
Models map[string]ModelConfig `yaml:"models"` Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

+14
View File
@@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>llama-swap</title>
</head>
<body>
<h1>llama-swap</h1>
<p>
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
</p>
</body>
</html>
+113 -21
View File
@@ -12,42 +12,134 @@
flex-direction: column; flex-direction: column;
font-family: "Courier New", Courier, monospace; font-family: "Courier New", Courier, monospace;
} }
#log-controls {
margin: 0.5em;
display: flex;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
}
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
#log-stream { #log-stream {
flex: 1; flex: 1;
margin: 1em; margin: 0.5em;
padding: 10px; padding: 1em;
background: #f4f4f4; background: #f4f4f4;
overflow-y: auto; overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */ white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */ word-wrap: break-word; /* Ensures long words wrap */
} }
.regex-error {
background-color: #ff0000 !important;
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
background-color: #333;
color: #fff;
}
#log-stream {
background: #444;
color: #fff;
}
#log-controls input {
background: #555;
color: #fff;
border: 1px solid #777;
}
#log-controls button {
background: #555;
color: #fff;
border: 1px solid #777;
}
}
</style> </style>
</head> </head>
<body> <body>
<pre id="log-stream">Waiting for logs... <pre id="log-stream">Waiting for logs...</pre>
</pre> <div id="log-controls">
<input type="text" id="filter-input" placeholder="regex filter">
<button id="clear-button">clear</button>
</div>
<script> <script>
// Establish an EventSource connection to the SSE endpoint const logStream = document.getElementById('log-stream');
if (typeof(EventSource) !== "undefined") { const filterInput = document.getElementById('filter-input');
const eventSource = new EventSource("/logs/streamSSE"); var logData = "";
let regexFilter = null;
eventSource.onmessage = function(event) { function setupEventSource() {
// Append the new log message to the <pre> element if (typeof(EventSource) !== "undefined") {
const logStream = document.getElementById('log-stream'); const eventSource = new EventSource("/logs/streamSSE");
logStream.textContent += event.data; eventSource.onmessage = function(event) {
logData += event.data;
render()
};
// Auto-scroll to the bottom eventSource.onerror = function(err) {
logStream.scrollTop = logStream.scrollHeight; logData = "EventSource failed: " + err.message;
}; };
} else {
eventSource.onerror = function(err) { logData = "SSE Not supported by this browser."
console.error("EventSource failed:", err); }
};
} else {
console.error("SSE not supported by this browser.");
} }
// poor-ai's react ¯\_(ツ)_/¯
function render() {
if (regexFilter) {
const lines = logData.split('\n');
const filteredLines = lines.filter(line => {
return regexFilter === null || regexFilter.test(line);
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
}
logStream.scrollTop = logStream.scrollHeight;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
}
} else {
regexFilter = null;
}
render();
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
});
setupEventSource();
updateFilter();
</script> </script>
</body> </body>
</html> </html>
+10
View File
@@ -0,0 +1,10 @@
package proxy
import "embed"
//go:embed html
var htmlFiles embed.FS
func getHTMLFile(path string) ([]byte, error) {
return htmlFiles.ReadFile("html/" + path)
}
+1 -1
View File
@@ -46,7 +46,7 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
w.buffer = w.buffer.Next() w.buffer = w.buffer.Next()
w.bufferMu.Unlock() w.bufferMu.Unlock()
w.broadcast(p) w.broadcast(bufferCopy)
return n, nil return n, nil
} }
+34 -60
View File
@@ -2,7 +2,6 @@ package proxy
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -86,36 +85,11 @@ func (p *Process) start() error {
// 3. The health check passes // 3. The health check passes
// //
// only in the third case will the process be considered Ready to accept // only in the third case will the process be considered Ready to accept
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background()) <-time.After(250 * time.Millisecond) // give process a bit of time to start
defer cancelHealthCheck(nil) // clean up
cmdWaitChan := make(chan error, 1)
healthCheckChan := make(chan error, 1)
go func() { if err := p.checkHealthEndpoint(); err != nil {
// possible cmd exits early
cmdWaitChan <- p.cmd.Wait()
}()
go func() {
<-time.After(250 * time.Millisecond) // give process a bit of time to start
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
}()
select {
case err := <-cmdWaitChan:
p.state = StateFailed p.state = StateFailed
if err != nil {
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
} else {
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
}
cancelHealthCheck(err)
return err return err
case err := <-healthCheckChan:
if err != nil {
p.state = StateFailed
return err
}
} }
if p.config.UnloadAfter > 0 { if p.config.UnloadAfter > 0 {
@@ -125,12 +99,17 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) { for range time.Tick(time.Second) {
if p.state != StateReady {
return
}
// wait for all inflight requests to complete and ticker // wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter) fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return
} }
} }
}() }()
@@ -148,42 +127,50 @@ func (p *Process) Stop() {
defer p.stateMutex.Unlock() defer p.stateMutex.Unlock()
if p.state != StateReady { if p.state != StateReady {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() called but Process State is not READY\n")
return return
} }
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state // this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.") fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n")
p.state = StateStopped p.state = StateStopped
return return
} }
// Pretty sure this stopping code needs some work for windows and sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// will be a source of pain in the future.
p.cmd.Process.Signal(syscall.SIGTERM)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
done := make(chan error, 1) sigtermNormal := make(chan error, 1)
go func() { go func() {
done <- p.cmd.Wait() sigtermNormal <- p.cmd.Wait()
}() }()
p.cmd.Process.Signal(syscall.SIGTERM)
select { select {
case <-ctx.Done(): case <-sigtermTimeout.Done():
fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill() p.cmd.Process.Kill()
p.cmd.Wait() case err := <-sigtermNormal:
case err := <-done:
if err != nil { if err != nil {
if err.Error() != "wait: no child processes" { if errno, ok := err.(syscall.Errno); ok {
// possible that simple-responder for testing is just not fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
// existing right, so suppress those errors. } else if exitError, ok := err.(*exec.ExitError); ok {
fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err) if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
}
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
} }
} }
} }
p.state = StateStopped p.state = StateStopped
} }
@@ -193,7 +180,7 @@ func (p *Process) CurrentState() ProcessState {
return p.state return p.state
} }
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error { func (p *Process) checkHealthEndpoint() error {
if p.config.Proxy == "" { if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health") return fmt.Errorf("no upstream available to check /health")
} }
@@ -225,24 +212,11 @@ func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
return err return err
} }
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req) resp, err := client.Do(req)
ttl := (maxDuration - time.Since(startTime)).Seconds() ttl := (maxDuration - time.Since(startTime)).Seconds()
if err != nil { if err != nil {
// check if the context was cancelled
select {
case <-ctx.Done():
err := context.Cause(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
return err
}
default:
}
// wait a bit longer for TCP connection issues // wait a bit longer for TCP connection issues
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl) fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
@@ -290,7 +264,7 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
+28 -2
View File
@@ -67,7 +67,6 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
assert.Contains(t, w.Body.String(), "unable to start process") assert.Contains(t, w.Body.String(), "unable to start process")
} }
// test that the process unloads after the TTL
func TestProcess_UnloadAfterTTL(t *testing.T) { func TestProcess_UnloadAfterTTL(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping long auto unload TTL test") t.Skip("skipping long auto unload TTL test")
@@ -79,7 +78,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -111,6 +110,33 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
assert.Equal(t, StateStopped, process.CurrentState()) assert.Equal(t, StateStopped, process.CurrentState())
} }
func TestProcess_LowTTLValue(t *testing.T) {
if true { // change this code to run this ...
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
for i := 0; i < 100; i++ {
t.Logf("Waiting before sending request %d", i)
time.Sleep(1500 * time.Millisecond)
expected := fmt.Sprintf("echo=test_%d", i)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expected)
}
}
// issue #19 // issue #19
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() { if testing.Short() {
+162 -21
View File
@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -35,11 +36,63 @@ func New(config *Config) *ProxyManager {
ginEngine: gin.New(), ginEngine: gin.New(),
} }
// Set up routes using the Gin engine if config.LogRequests {
pm.ginEngine.POST("/v1/chat/completions", pm.proxyChatRequestHandler) pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
}
// see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
return
}
c.Next()
})
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12 // Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyChatRequestHandler) pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
// Support embeddings
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -48,7 +101,33 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.NoRoute(pm.proxyNoRouteHandler) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response
htmlData, err := getHTMLFile("index.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
_, err = c.Writer.Write(htmlData)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
return
}
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := getHTMLFile("favicon.ico"); err == nil {
c.Data(http.StatusOK, "image/x-icon", data)
} else {
c.String(http.StatusInternalServerError, err.Error())
}
})
// Disable console color for testing // Disable console color for testing
gin.DisableConsoleColor() gin.DisableConsoleColor()
@@ -86,7 +165,11 @@ func (pm *ProxyManager) stopProcesses() {
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := []interface{}{}
for id := range pm.config.Models { for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
data = append(data, map[string]interface{}{ data = append(data, map[string]interface{}{
"id": id, "id": id,
"object": "model", "object": "model",
@@ -104,7 +187,7 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
// Encode the data as JSON and write it to the response writer // Encode the data as JSON and write it to the response writer
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil { if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON")) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
return return
} }
} }
@@ -113,7 +196,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
// Check if requestedModel contains a / // Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 { if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx] profileName = requestedModel[:idx]
@@ -132,6 +215,21 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel) return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
} }
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap // exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName) requestedProcessKey := ProcessKeyName(profileName, realModelName)
@@ -170,25 +268,68 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
return pm.currentProcesses[requestedProcessKey], nil return pm.currentProcesses[requestedProcessKey], nil
} }
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) { func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body) requestedModel := c.Param("model_id")
if err != nil {
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON")) if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return return
} }
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
} else {
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
}
}
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
var html strings.Builder
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
// Extract keys and sort them
var modelIDs []string
for modelID, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
modelIDs = append(modelIDs, modelID)
}
sort.Strings(modelIDs)
// Iterate over sorted keys
for _, modelID := range modelIDs {
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
}
html.WriteString("</ul></body></html>")
c.Header("Content-Type", "text/html")
c.String(http.StatusOK, html.String())
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return
}
var requestBody map[string]interface{} var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON")) pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return return
} }
model, ok := requestBody["model"].(string) model, ok := requestBody["model"].(string)
if !ok { if !ok {
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("missing or invalid 'model' key")) pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return return
} }
if process, err := pm.swapModel(model); err != nil { if process, err := pm.swapModel(model); err != nil {
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} else { } else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -201,14 +342,14 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
} }
} }
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) { func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
// since maps are unordered, just use the first available process if one exists acceptHeader := c.GetHeader("Accept")
for _, process := range pm.currentProcesses {
process.ProxyRequest(c.Writer, c.Request)
return
}
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request")) if strings.Contains(acceptHeader, "application/json") {
c.JSON(statusCode, gin.H{"error": message})
} else {
c.String(statusCode, message)
}
} }
func ProcessKeyName(groupName, modelName string) string { func ProcessKeyName(groupName, modelName string) string {
+10 -16
View File
@@ -1,7 +1,6 @@
package proxy package proxy
import ( import (
"embed"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@@ -9,12 +8,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
//go:embed html/logs.html
var logsHTML []byte
// make sure embed is kept there by the IDE auto-package importer
var _ = embed.FS{}
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
@@ -23,9 +16,14 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
c.Header("Content-Type", "text/html") c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response // Write the embedded HTML content to the response
_, err := c.Writer.Write(logsHTML) logsHTML, err := getHTMLFile("logs.html")
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err)) c.String(http.StatusInternalServerError, err.Error())
return
}
_, err = c.Writer.Write(logsHTML)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
return return
} }
} else { } else {
@@ -50,7 +48,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("Streaming unsupported")) c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
return return
} }
@@ -60,11 +58,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
if !skipHistory { if !skipHistory {
history := pm.logMonitor.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
_, err := c.Writer.Write(history) c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
flusher.Flush() flusher.Flush()
} }
} }
@@ -75,7 +69,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
case msg := <-ch: case msg := <-ch:
_, err := c.Writer.Write(msg) _, err := c.Writer.Write(msg)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) // just break the loop if we can't write for some reason
return return
} }
flusher.Flush() flusher.Flush()
+44
View File
@@ -210,3 +210,47 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
// Ensure all expected models were returned // Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned") assert.Empty(t, expectedModels, "not all expected models were returned")
} }
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}